From 172d1de660be3e67e4a57c992c2e9edddcf32bbf Mon Sep 17 00:00:00 2001 From: Ruibiao Chen Date: Mon, 30 Jan 2023 15:37:21 +0800 Subject: [PATCH] Support stream priority for standalone executor (#49939) * Support stream priority for standalone executor * Fix compile error * Fix compile error * Fix compile error * Fix compile error * Fix compile error --- .../distributed/auto_parallel/dist_attr.cc | 9 +++ .../distributed/auto_parallel/dist_attr.h | 7 ++ .../fast_threaded_ssa_graph_executor.cc | 3 +- .../details/threaded_ssa_graph_executor.cc | 3 +- .../interpreter/interpreter_util.cc | 33 +++++---- .../interpreter/stream_analyzer.cc | 20 +++-- .../framework/new_executor/interpretercore.cc | 16 ++-- .../framework/new_executor/interpretercore.h | 8 +- .../new_executor/new_executor_defs.h | 11 ++- paddle/fluid/platform/device_context.cc | 73 ++++++++++++++----- paddle/fluid/platform/device_context.h | 3 +- paddle/fluid/pybind/auto_parallel_py.cc | 3 + paddle/fluid/pybind/cuda_streams_py.cc | 25 +++---- paddle/phi/backends/gpu/cuda/cuda_info.cc | 7 ++ paddle/phi/backends/gpu/gpu_context.cc | 12 +-- paddle/phi/backends/gpu/gpu_context.h | 6 +- paddle/phi/backends/gpu/gpu_info.h | 3 + paddle/phi/backends/gpu/rocm/rocm_info.cc | 7 ++ paddle/phi/core/cuda_stream.h | 50 +++++++------ .../test_standalone_custom_stream.py | 2 + 20 files changed, 206 insertions(+), 95 deletions(-) diff --git a/paddle/fluid/distributed/auto_parallel/dist_attr.cc b/paddle/fluid/distributed/auto_parallel/dist_attr.cc index 044c3819797..06754fd6b16 100644 --- a/paddle/fluid/distributed/auto_parallel/dist_attr.cc +++ b/paddle/fluid/distributed/auto_parallel/dist_attr.cc @@ -293,6 +293,7 @@ std::vector OperatorDistAttr::fields_{"process_mesh", "impl_idx", "is_recompute", "execution_stream", + "stream_priority", "scheduling_priority"}; OperatorDistAttr::OperatorDistAttr(const OpDesc& op) { @@ -318,6 +319,8 @@ OperatorDistAttr& OperatorDistAttr::operator=( std::swap(this->impl_idx_, tmp.impl_idx_); std::swap(this->is_recompute_, tmp.is_recompute_); std::swap(this->execution_stream_, tmp.execution_stream_); + std::swap(this->stream_priority_, tmp.stream_priority_); + std::swap(this->scheduling_priority_, tmp.scheduling_priority_); std::swap(this->annotated_, tmp.annotated_); // Note: Make sure all tensor dist attr has the same process_mesh set_process_mesh(this->process_mesh_); @@ -349,6 +352,7 @@ void OperatorDistAttr::initialize(const OpDesc* op) { impl_idx_ = 0; is_recompute_ = false; execution_stream_ = kDefault; + stream_priority_ = 0; scheduling_priority_ = 0; } @@ -361,6 +365,7 @@ void OperatorDistAttr::copy_from(const OperatorDistAttr& dist_attr) { set_impl_idx(dist_attr.impl_idx()); set_is_recompute(dist_attr.is_recompute()); set_execution_stream(dist_attr.execution_stream()); + set_stream_priority(dist_attr.stream_priority()); set_scheduling_priority(dist_attr.scheduling_priority()); set_annotated(dist_attr.annotated()); } @@ -599,6 +604,7 @@ std::string OperatorDistAttr::to_string() const { str += "{impl_type: " + impl_type_ + ", "; str += "impl_idx: " + std::to_string(impl_idx_) + ", "; str += "execution_stream: " + execution_stream_ + ", "; + str += "stream_priority: " + std::to_string(stream_priority_) + ", "; str += "scheduling_priority: " + std::to_string(scheduling_priority_) + ", "; str += "annotated: [" + str_join(annotated_) + "], "; str += "\nprocess_mesh: " + process_mesh_.to_string() + ", "; @@ -684,6 +690,9 @@ bool operator==(const OperatorDistAttr& lhs, const OperatorDistAttr& rhs) { if (lhs.execution_stream() != rhs.execution_stream()) { return false; } + if (lhs.stream_priority() != rhs.stream_priority()) { + return false; + } if (lhs.scheduling_priority() != rhs.scheduling_priority()) { return false; } diff --git a/paddle/fluid/distributed/auto_parallel/dist_attr.h b/paddle/fluid/distributed/auto_parallel/dist_attr.h index b38a21f336a..90bbdf30822 100644 --- a/paddle/fluid/distributed/auto_parallel/dist_attr.h +++ b/paddle/fluid/distributed/auto_parallel/dist_attr.h @@ -222,6 +222,12 @@ class OperatorDistAttr { execution_stream_ = execution_stream; } + int stream_priority() const { return stream_priority_; } + + void set_stream_priority(int stream_priority) { + stream_priority_ = stream_priority; + } + int64_t scheduling_priority() const { return scheduling_priority_; } void set_scheduling_priority(int64_t scheduling_priority) { @@ -289,6 +295,7 @@ class OperatorDistAttr { int64_t impl_idx_ = 0; bool is_recompute_ = false; std::string execution_stream_ = kDefault; + int stream_priority_ = 0; // lower value, higher priority int64_t scheduling_priority_ = 0; // lower value, higher priority std::map annotated_; }; diff --git a/paddle/fluid/framework/details/fast_threaded_ssa_graph_executor.cc b/paddle/fluid/framework/details/fast_threaded_ssa_graph_executor.cc index c436f29b59f..1a6fef67923 100644 --- a/paddle/fluid/framework/details/fast_threaded_ssa_graph_executor.cc +++ b/paddle/fluid/framework/details/fast_threaded_ssa_graph_executor.cc @@ -46,7 +46,8 @@ FastThreadedSSAGraphExecutor::FastThreadedSSAGraphExecutor( platform::EmplaceDeviceContexts( &fetch_ctxs_, places, - /*disable_setting_default_stream_for_allocator=*/true); + /*disable_setting_default_stream_for_allocator=*/true, + /*stream_priority=*/0); if (ir::IsTopologySortOperationsUnique(*graph_)) { VLOG(10) << "Change thread number to 1 because the toposort order is unique"; diff --git a/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc b/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc index 8a4bd32b155..4489cd9b00b 100644 --- a/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc +++ b/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc @@ -41,7 +41,8 @@ ThreadedSSAGraphExecutor::ThreadedSSAGraphExecutor( platform::EmplaceDeviceContexts( &fetch_ctxs_, places, - /*disable_setting_default_stream_for_allocator=*/true); + /*disable_setting_default_stream_for_allocator=*/true, + /*stream_priority=*/0); if (strategy_.num_iteration_per_run_ > 1) { int read_op_num = 0; diff --git a/paddle/fluid/framework/new_executor/interpreter/interpreter_util.cc b/paddle/fluid/framework/new_executor/interpreter/interpreter_util.cc index f98acfdccfd..507f0302c57 100644 --- a/paddle/fluid/framework/new_executor/interpreter/interpreter_util.cc +++ b/paddle/fluid/framework/new_executor/interpreter/interpreter_util.cc @@ -666,21 +666,28 @@ bool BuildOpFuncList(const platform::Place& place, op_func_node.output_index = outs_name2id; const OperatorDistAttr* dist_attr = block.Op(i)->DistAttr(); - if (dist_attr && - dist_attr->execution_stream() != distributed::auto_parallel::kDefault) { - op_func_node.execution_stream_ = dist_attr->execution_stream(); - } - if (dist_attr) { - op_func_node.priority_ = dist_attr->scheduling_priority(); - } else if (interpreter::IsCommunicationOp(op_type)) { - // NOTE(Ruibiao): Dispatching computation before communication improves - // multi-stream overlap when the time cost of communication less than that - // of the calculation (e.g., ResNet50_bs128_pure_fp16 N4C32 training). - op_func_node.priority_ = 1; + if (dist_attr->execution_stream() != + distributed::auto_parallel::kDefault) { + op_func_node.execution_stream_ = dist_attr->execution_stream(); + } + op_func_node.stream_priority_ = dist_attr->stream_priority(); + op_func_node.scheduling_priority_ = dist_attr->scheduling_priority(); + } else { + if (interpreter::IsCommunicationOp(op_type)) { + // NOTE(Ruibiao): Dispatching computation before communication improves + // multi-stream overlap when the time cost of communication less than + // that of the calculation (e.g., ResNet50_bs128_pure_fp16 N4C32 + // training). + op_func_node.scheduling_priority_ = 1; + } } - VLOG(6) << "scheduling priority of " << op_type << " : " - << op_func_node.priority_; + + VLOG(6) << op_type + << " : [execution_stream, stream_priority, scheduling_priority] = [" + << op_func_node.execution_stream_ << ", " + << op_func_node.stream_priority_ << ", " + << op_func_node.scheduling_priority_ << "]"; SingleStreamGuard single_stream_guard(ops[i]); diff --git a/paddle/fluid/framework/new_executor/interpreter/stream_analyzer.cc b/paddle/fluid/framework/new_executor/interpreter/stream_analyzer.cc index 8d700f51012..dde504280d2 100644 --- a/paddle/fluid/framework/new_executor/interpreter/stream_analyzer.cc +++ b/paddle/fluid/framework/new_executor/interpreter/stream_analyzer.cc @@ -39,7 +39,9 @@ class ContextManager { } std::shared_future> Get( - const std::string& type, const platform::Place& place) { + const std::string& type, + const platform::Place& place, + int stream_priority) { std::lock_guard lk(ctx_mtx_); VLOG(6) << "Get dev_ctx for " << type << " - " << place; @@ -48,7 +50,8 @@ class ContextManager { platform::EmplaceDeviceContexts( &ctxs, {place}, - /*disable_setting_default_stream_for_allocator=*/true); + /*disable_setting_default_stream_for_allocator=*/true, + stream_priority); } return ctxs[place]; } @@ -142,6 +145,7 @@ DeviceContext* StreamAnalyzer::ParseDeviceContext( auto& op = op_func_node.operator_base_; auto& op_type = op->Type(); const std::string& execution_stream = op_func_node.execution_stream_; + const int stream_priority = op_func_node.stream_priority_; ContextManager& ctx_manager = ContextManager::Instance(); // only gpu/npu need update. xpu not need, because xpu memcpy op kernel is @@ -152,15 +156,21 @@ DeviceContext* StreamAnalyzer::ParseDeviceContext( << ", execution stream = " << execution_stream; if (execution_stream != kDefaultStream) { return ctx_manager - .Get(std::string(kCustomStream) + "-" + execution_stream, place_) + .Get(std::string(kCustomStream) + "-" + execution_stream, + place_, + stream_priority) .get() .get(); } if (op_type == interpreter::kMemcpyD2H) { - return ctx_manager.Get(std::string(kD2HStream), place_).get().get(); + return ctx_manager.Get(std::string(kD2HStream), place_, stream_priority) + .get() + .get(); } else if (op_type == interpreter::kMemcpyH2D) { - return ctx_manager.Get(std::string(kH2DStream), place_).get().get(); + return ctx_manager.Get(std::string(kH2DStream), place_, stream_priority) + .get() + .get(); } #if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) diff --git a/paddle/fluid/framework/new_executor/interpretercore.cc b/paddle/fluid/framework/new_executor/interpretercore.cc index 63525330ea6..537170e2890 100644 --- a/paddle/fluid/framework/new_executor/interpretercore.cc +++ b/paddle/fluid/framework/new_executor/interpretercore.cc @@ -139,13 +139,15 @@ InterpreterCore::InterpreterCore(const platform::Place& place, } var_scope_.SetLocalScope(local_scope_); - instruction_prority_less = [this](size_t lhs, size_t rhs) { - Priority lhs_prority = vec_instruction_[lhs].GetPriority(); - Priority rhs_prority = vec_instruction_[rhs].GetPriority(); - if (lhs_prority == rhs_prority) { + instruction_scheduling_priority_less = [this](size_t lhs, size_t rhs) { + SchedulingPriority lhs_scheduling_priority = + vec_instruction_[lhs].GetSchedulingPriority(); + SchedulingPriority rhs_scheduling_priority = + vec_instruction_[rhs].GetSchedulingPriority(); + if (lhs_scheduling_priority == rhs_scheduling_priority) { return lhs < rhs; } - return lhs_prority > rhs_prority; + return lhs_scheduling_priority > rhs_scheduling_priority; }; PrepareForCUDAGraphCapture(); @@ -1089,7 +1091,7 @@ void InterpreterCore::RunInstructionAsync(size_t instr_id) { // scheduling, the priority order involved cross-thread scheduling is not // guaranteed. Only Ops scheduled by the same AddTask call have the guarantee // of priority order. - SchedulingQueue ready_ops(instruction_prority_less); + SchedulingQueue ready_ops(instruction_scheduling_priority_less); ready_ops.push(instr_id); while (!ready_ops.empty()) { instr_id = ready_ops.top(); @@ -1427,7 +1429,7 @@ void InterpreterCore::AnalyseExecuteOrderForTrace() { }; std::vector trace_order; - SchedulingQueue ready_ops(instruction_prority_less); + SchedulingQueue ready_ops(instruction_scheduling_priority_less); for (size_t instr_id = 0; instr_id < dependecy_count_.size(); ++instr_id) { if (dependecy_count_[instr_id] == 0) { diff --git a/paddle/fluid/framework/new_executor/interpretercore.h b/paddle/fluid/framework/new_executor/interpretercore.h index 53625c87938..cf7f3a13dcf 100644 --- a/paddle/fluid/framework/new_executor/interpretercore.h +++ b/paddle/fluid/framework/new_executor/interpretercore.h @@ -79,9 +79,11 @@ class InterpreterCore { const platform::Place& GetPlace() const { return place_; } private: - using InstructionPriorityLess = std::function; + using InstructionSchedulingPriorityLess = std::function; using SchedulingQueue = - std::priority_queue, InstructionPriorityLess>; + std::priority_queue, + InstructionSchedulingPriorityLess>; // build graph void Convert(std::vector* op_func_nodes); @@ -181,7 +183,7 @@ class InterpreterCore { int64_t sync_op_num_{-1}; std::vector trace_execute_order_; - InstructionPriorityLess instruction_prority_less; + InstructionSchedulingPriorityLess instruction_scheduling_priority_less; }; std::shared_ptr CreateInterpreterCore( diff --git a/paddle/fluid/framework/new_executor/new_executor_defs.h b/paddle/fluid/framework/new_executor/new_executor_defs.h index 3b437275b04..8c62e9d26d4 100644 --- a/paddle/fluid/framework/new_executor/new_executor_defs.h +++ b/paddle/fluid/framework/new_executor/new_executor_defs.h @@ -32,12 +32,12 @@ namespace framework { using OpKernelComputeFunc = std::function; -using Priority = int64_t; +using SchedulingPriority = int64_t; constexpr const char* kCoalesceTensor = "coalesce_tensor"; // stream types -constexpr const char* kCustomStream = "CustromStream"; +constexpr const char* kCustomStream = "CustomStream"; constexpr const char* kDefaultStream = "DefaultStream"; constexpr const char* kD2HStream = "D2HStream"; constexpr const char* kH2DStream = "H2DStream"; @@ -263,6 +263,7 @@ enum class OpFuncType { class RuntimeInferShapeContext; struct OpFuncNode { + int stream_priority_{0}; // lower value, higher priority // fit for phi kernel phi::Kernel* phi_kernel_{nullptr}; // not owned platform::DeviceContext* dev_ctx_; // not owned @@ -279,7 +280,7 @@ struct OpFuncNode { OpFuncType type_; OpKernelComputeFunc kernel_func_; - Priority priority_{0}; // lower value, higher priority + SchedulingPriority scheduling_priority_{0}; // lower value, higher priority }; class Instruction { @@ -369,7 +370,9 @@ class Instruction { void ClearInplace(); - Priority GetPriority() const { return op_func_node_.priority_; } + SchedulingPriority GetSchedulingPriority() const { + return op_func_node_.scheduling_priority_; + } private: bool is_artificial_; // Instruction is artificial means that it is only used diff --git a/paddle/fluid/platform/device_context.cc b/paddle/fluid/platform/device_context.cc index 4ec96f606fa..4aecc6e0412 100644 --- a/paddle/fluid/platform/device_context.cc +++ b/paddle/fluid/platform/device_context.cc @@ -25,13 +25,13 @@ limitations under the License. */ #include "paddle/fluid/platform/place.h" #include "paddle/fluid/platform/profiler.h" #include "paddle/fluid/platform/profiler/event_tracing.h" -#include "paddle/phi/backends/gpu/gpu_context.h" #include "paddle/phi/core/allocator.h" #include "paddle/phi/core/expect.h" #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) #include "paddle/fluid/memory/allocation/cuda_device_context_allocator.h" #include "paddle/fluid/platform/cuda_device_guard.h" +#include "paddle/phi/backends/gpu/gpu_context.h" #endif #ifdef PADDLE_WITH_MLU @@ -145,12 +145,37 @@ void DeviceContextPool::SetDeviceContexts( external_device_contexts_ = dev_ctxs; } +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) +template +typename std::enable_if::value, + DevCtx*>::type +ConstructDevCtx(const platform::Place& p, /*unused*/ int stream_priority = 0) { + return new DevCtx(p); +} + +template +typename std::enable_if::value, + DevCtx*>::type +ConstructDevCtx(const platform::Place& p, int stream_priority) { + return new DevCtx(p, /*init=*/true, stream_priority); +} +#else +template +DevCtx* ConstructDevCtx(const platform::Place& p, + /*unused*/ int stream_priority) { + return new DevCtx(p); +} +#endif + template std::unique_ptr CreateDeviceContext( const platform::Place& p, - bool disable_setting_default_stream_for_allocator = false) { + bool disable_setting_default_stream_for_allocator = false, + int stream_priority = 0) { using PtrType = std::unique_ptr; - auto* dev_ctx = new DevCtx(p); + + DevCtx* dev_ctx = ConstructDevCtx(p, stream_priority); + if (is_gpu_place(p)) { #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) auto* cuda_ctx = dynamic_cast(dev_ctx); @@ -201,21 +226,24 @@ inline void EmplaceDeviceContext( std::map>>* place_to_device_context, platform::Place place, - bool disable_setting_default_stream_for_allocator) { + bool disable_setting_default_stream_for_allocator, + int stream_priority) { // lazy evaluation. i.e., only create device context at first `Get` place_to_device_context->emplace( place, std::async(std::launch::deferred, CreateDeviceContext, place, - disable_setting_default_stream_for_allocator)); + disable_setting_default_stream_for_allocator, + stream_priority)); } void EmplaceDeviceContexts( std::map>>* place_to_device_context, const std::vector& places, - bool disable_setting_default_stream_for_allocator) { + bool disable_setting_default_stream_for_allocator, + int stream_priority) { PADDLE_ENFORCE_GT( places.size(), 0, @@ -234,19 +262,22 @@ void EmplaceDeviceContexts( EmplaceDeviceContext( place_to_device_context, p, - disable_setting_default_stream_for_allocator); + disable_setting_default_stream_for_allocator, + /*unused*/ stream_priority); #else EmplaceDeviceContext( place_to_device_context, p, - disable_setting_default_stream_for_allocator); + disable_setting_default_stream_for_allocator, + /*unused*/ stream_priority); #endif } else if (platform::is_gpu_place(p)) { #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) EmplaceDeviceContext( place_to_device_context, p, - disable_setting_default_stream_for_allocator); + disable_setting_default_stream_for_allocator, + stream_priority); #else PADDLE_THROW( platform::errors::Unimplemented("CUDAPlace is not supported. Please " @@ -257,7 +288,8 @@ void EmplaceDeviceContexts( EmplaceDeviceContext( place_to_device_context, p, - disable_setting_default_stream_for_allocator); + disable_setting_default_stream_for_allocator, + /*unused*/ stream_priority); #else PADDLE_THROW(platform::errors::Unimplemented( "CUDAPlace is not supported. Please re-compile with WITH_GPU " @@ -268,7 +300,8 @@ void EmplaceDeviceContexts( EmplaceDeviceContext( place_to_device_context, p, - disable_setting_default_stream_for_allocator); + disable_setting_default_stream_for_allocator, + /*unused*/ stream_priority); #else PADDLE_THROW( platform::errors::Unimplemented("XPUPlace is not supported. Please " @@ -279,7 +312,8 @@ void EmplaceDeviceContexts( EmplaceDeviceContext( place_to_device_context, p, - disable_setting_default_stream_for_allocator); + disable_setting_default_stream_for_allocator, + /*unused*/ stream_priority); #else PADDLE_THROW( platform::errors::Unimplemented("MLUPlace is not supported. Please " @@ -290,7 +324,8 @@ void EmplaceDeviceContexts( EmplaceDeviceContext( place_to_device_context, p, - disable_setting_default_stream_for_allocator); + disable_setting_default_stream_for_allocator, + /*unused*/ stream_priority); #else PADDLE_THROW( platform::errors::Unimplemented("IPUPlace is not supported. Please " @@ -301,7 +336,8 @@ void EmplaceDeviceContexts( EmplaceDeviceContext( place_to_device_context, p, - disable_setting_default_stream_for_allocator); + disable_setting_default_stream_for_allocator, + /*unused*/ stream_priority); #else PADDLE_THROW(platform::errors::Unimplemented( "NPUPlace is not supported. Please " @@ -312,7 +348,8 @@ void EmplaceDeviceContexts( EmplaceDeviceContext( place_to_device_context, p, - disable_setting_default_stream_for_allocator); + disable_setting_default_stream_for_allocator, + /*unused*/ stream_priority); #else PADDLE_THROW(platform::errors::Unimplemented( "NPUPinnedPlace is not supported. Please re-compile with " @@ -324,7 +361,8 @@ void EmplaceDeviceContexts( EmplaceDeviceContext( place_to_device_context, p, - disable_setting_default_stream_for_allocator); + disable_setting_default_stream_for_allocator, + /*unused*/ stream_priority); #else PADDLE_THROW(platform::errors::Unimplemented( "CustomPlace is not supported. Please re-compile with " @@ -339,7 +377,8 @@ DeviceContextPool::DeviceContextPool( const std::vector& places) { EmplaceDeviceContexts(&device_contexts_, places, - /*disable_setting_default_stream_for_allocator=*/false); + /*disable_setting_default_stream_for_allocator=*/false, + /*stream_priority=*/0); } #ifdef PADDLE_WITH_IPU diff --git a/paddle/fluid/platform/device_context.h b/paddle/fluid/platform/device_context.h index 94236fcff1a..8b944aef865 100644 --- a/paddle/fluid/platform/device_context.h +++ b/paddle/fluid/platform/device_context.h @@ -346,7 +346,8 @@ void EmplaceDeviceContexts( std::map>>* place_to_device_context, const std::vector& places, - bool disable_setting_default_stream_for_allocator); + bool disable_setting_default_stream_for_allocator, + int stream_priority); /*! \brief device context pool singleton */ class DeviceContextPool { diff --git a/paddle/fluid/pybind/auto_parallel_py.cc b/paddle/fluid/pybind/auto_parallel_py.cc index c650a008e3d..7b734b6d5b8 100644 --- a/paddle/fluid/pybind/auto_parallel_py.cc +++ b/paddle/fluid/pybind/auto_parallel_py.cc @@ -288,6 +288,9 @@ void BindAutoParallel(py::module *m) { .def_property("execution_stream", &OperatorDistAttr::execution_stream, &OperatorDistAttr::set_execution_stream) + .def_property("stream_priority", + &OperatorDistAttr::stream_priority, + &OperatorDistAttr::set_stream_priority) .def_property("scheduling_priority", &OperatorDistAttr::scheduling_priority, &OperatorDistAttr::set_scheduling_priority) diff --git a/paddle/fluid/pybind/cuda_streams_py.cc b/paddle/fluid/pybind/cuda_streams_py.cc index f805696138f..8898088596e 100644 --- a/paddle/fluid/pybind/cuda_streams_py.cc +++ b/paddle/fluid/pybind/cuda_streams_py.cc @@ -252,8 +252,6 @@ void BindCudaStream(py::module *m_ptr) { PADDLE_THROW(platform::errors::InvalidArgument( "Priority should be 1(high) or 2(normal) ")); } - auto prio = phi::CUDAStream::Priority(priority); - auto stream_flag = phi::CUDAStream::StreamFlag::kStreamNonBlocking; if (place == nullptr) { int curr_device_id = platform::GetCurrentDeviceId(); @@ -261,7 +259,10 @@ void BindCudaStream(py::module *m_ptr) { place = &place_tmp; } - new (&self) phi::CUDAStream(*place, prio, stream_flag); + auto stream_flag = phi::CUDAStream::StreamFlag::kStreamNonBlocking; + // seting priority 1(high) and 2(normal) correspond to the actual + // cuda stream priority -1 and 0. + new (&self) phi::CUDAStream(*place, priority - 2, stream_flag); #else PADDLE_THROW(platform::errors::Unavailable( "Class CUDAStream can only be initialized on the GPU platform.")); @@ -277,8 +278,6 @@ void BindCudaStream(py::module *m_ptr) { PADDLE_THROW(platform::errors::InvalidArgument( "Priority should be 1(high) or 2(normal) ")); } - auto prio = phi::CUDAStream::Priority(priority); - auto stream_flag = phi::CUDAStream::StreamFlag::kStreamNonBlocking; int device_count = platform::GetGPUDeviceCount(); if (device < 0) { @@ -291,8 +290,11 @@ void BindCudaStream(py::module *m_ptr) { device)); } - new (&self) - phi::CUDAStream(platform::CUDAPlace(device), prio, stream_flag); + auto stream_flag = phi::CUDAStream::StreamFlag::kStreamNonBlocking; + // seting priority 1(high) and 2(normal) correspond to the actual + // cuda stream priority -1 and 0. + new (&self) phi::CUDAStream( + platform::CUDAPlace(device), priority - 2, stream_flag); #else PADDLE_THROW(platform::errors::Unavailable( "Class CUDAStream can only be initialized on the GPU platform.")); @@ -302,13 +304,10 @@ void BindCudaStream(py::module *m_ptr) { py::arg("priority") = 2) .def("__init__", [](phi::CUDAStream &self) { #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) - auto prio = phi::CUDAStream::Priority::kNormal; - auto stream_flag = phi::CUDAStream::StreamFlag::kStreamNonBlocking; - int device_id = platform::GetCurrentDeviceId(); - - new (&self) - phi::CUDAStream(platform::CUDAPlace(device_id), prio, stream_flag); + auto stream_flag = phi::CUDAStream::StreamFlag::kStreamNonBlocking; + new (&self) phi::CUDAStream( + platform::CUDAPlace(device_id), /*priority=*/0, stream_flag); #else PADDLE_THROW(platform::errors::Unavailable( "Class CUDAStream can only be initialized on the GPU platform.")); diff --git a/paddle/phi/backends/gpu/cuda/cuda_info.cc b/paddle/phi/backends/gpu/cuda/cuda_info.cc index f6bff1c7b3c..1ef1327cd6d 100644 --- a/paddle/phi/backends/gpu/cuda/cuda_info.cc +++ b/paddle/phi/backends/gpu/cuda/cuda_info.cc @@ -196,6 +196,13 @@ std::array GetGpuMaxGridDimSize(int id) { return ret; } +std::pair GetGpuStreamPriorityRange() { + int least_priority, greatest_priority; + PADDLE_ENFORCE_GPU_SUCCESS( + cudaDeviceGetStreamPriorityRange(&least_priority, &greatest_priority)); + return std::make_pair(least_priority, greatest_priority); +} + const gpuDeviceProp &GetDeviceProperties(int id) { std::call_once(g_device_props_size_init_flag, [&] { int gpu_num = 0; diff --git a/paddle/phi/backends/gpu/gpu_context.cc b/paddle/phi/backends/gpu/gpu_context.cc index 9623fc8af74..e87cb48807f 100644 --- a/paddle/phi/backends/gpu/gpu_context.cc +++ b/paddle/phi/backends/gpu/gpu_context.cc @@ -218,7 +218,7 @@ struct GPUContext::Impl { InitDnnWorkspace(); } - void PartialInitWithoutAllocator() { + void PartialInitWithoutAllocator(int stream_priority) { owned_ = true; stream_owned_ = true; backends::gpu::GPUDeviceGuard guard(place_.device); @@ -230,7 +230,7 @@ struct GPUContext::Impl { &max_threads_per_mp_, &max_threads_per_block_, &max_grid_dim_size_); - stream_ = new CUDAStream(place_); + stream_ = new CUDAStream(place_, stream_priority); } void PartialInitWithAllocator() { @@ -818,10 +818,10 @@ GPUContext::GPUContext(GPUContext&&) = default; GPUContext& GPUContext::operator=(GPUContext&&) = default; -GPUContext::GPUContext(const GPUPlace& place, bool init) +GPUContext::GPUContext(const GPUPlace& place, bool init, int stream_priority) : DeviceContext(), impl_(std::make_unique(place)) { if (init) { - impl_->PartialInitWithoutAllocator(); + impl_->PartialInitWithoutAllocator(stream_priority); } } @@ -1001,8 +1001,8 @@ void GPUContext::SetDnnWorkspaceHandle(DnnWorkspaceHandle* handle) { impl_->workspace_ = handle; } -void GPUContext::PartialInitWithoutAllocator() { - impl_->PartialInitWithoutAllocator(); +void GPUContext::PartialInitWithoutAllocator(int stream_priority) { + impl_->PartialInitWithoutAllocator(stream_priority); } void GPUContext::PartialInitWithAllocator() { diff --git a/paddle/phi/backends/gpu/gpu_context.h b/paddle/phi/backends/gpu/gpu_context.h index 0b34d95eaf0..a40938b61bb 100644 --- a/paddle/phi/backends/gpu/gpu_context.h +++ b/paddle/phi/backends/gpu/gpu_context.h @@ -81,7 +81,9 @@ class DnnWorkspaceHandle { class PADDLE_API GPUContext : public DeviceContext, public TypeInfoTraits { public: - explicit GPUContext(const GPUPlace& place, bool init = true); + explicit GPUContext(const GPUPlace& place, + bool init = true, + int stream_priority = 0); GPUContext(GPUContext&&); GPUContext& operator=(GPUContext&&); @@ -198,7 +200,7 @@ class PADDLE_API GPUContext : public DeviceContext, // Note that this is a trick implementation, which can be used to partially // initialize when the SetAllocator interface is not called. - void PartialInitWithoutAllocator(); + void PartialInitWithoutAllocator(int stream_priority = 0); // Note that this is a trick implementation that can be used to initialize // resources that require an Allocator when the SetAllocator interface is // called. diff --git a/paddle/phi/backends/gpu/gpu_info.h b/paddle/phi/backends/gpu/gpu_info.h index 0f3c984ce85..6ba8863bc39 100644 --- a/paddle/phi/backends/gpu/gpu_info.h +++ b/paddle/phi/backends/gpu/gpu_info.h @@ -17,6 +17,7 @@ limitations under the License. */ #include #include +#include #include #include "paddle/phi/backends/gpu/gpu_types.h" @@ -58,6 +59,8 @@ int GetCurrentDeviceId(); //! Get the maximum GridDim size for GPU buddy allocator. std::array GetGpuMaxGridDimSize(int); +std::pair GetGpuStreamPriorityRange(); + //! Get a list of device ids from environment variable or use all. std::vector GetSelectedDevices(); diff --git a/paddle/phi/backends/gpu/rocm/rocm_info.cc b/paddle/phi/backends/gpu/rocm/rocm_info.cc index 8e4e06af801..edc23479c92 100644 --- a/paddle/phi/backends/gpu/rocm/rocm_info.cc +++ b/paddle/phi/backends/gpu/rocm/rocm_info.cc @@ -200,6 +200,13 @@ std::array GetGpuMaxGridDimSize(int id) { return ret; } +std::pair GetGpuStreamPriorityRange() { + int least_priority, greatest_priority; + PADDLE_ENFORCE_GPU_SUCCESS( + hipDeviceGetStreamPriorityRange(&least_priority, &greatest_priority)); + return std::make_pair(least_priority, greatest_priority); +} + const gpuDeviceProp &GetDeviceProperties(int id) { std::call_once(g_device_props_size_init_flag, [&] { int gpu_num = 0; diff --git a/paddle/phi/core/cuda_stream.h b/paddle/phi/core/cuda_stream.h index 160a31262b3..ff2f4846b17 100644 --- a/paddle/phi/core/cuda_stream.h +++ b/paddle/phi/core/cuda_stream.h @@ -35,12 +35,6 @@ namespace phi { // Currently, CudaStream is used in python-side API only class CUDAStream { public: - enum class Priority : uint8_t { - kNull = 0x0, - kHigh = 0x1, - kNormal = 0x2, - }; - enum class StreamFlag : uint8_t { kDefaultFlag = 0x0, kStreamNonBlocking = 0x1, @@ -50,29 +44,41 @@ class CUDAStream { CUDAStream(const Place& place, const Stream& stream) : place_(place), stream_(stream) {} CUDAStream(const Place& place, - const Priority& priority = Priority::kNormal, + const int priority = 0, const StreamFlag& flag = StreamFlag::kDefaultFlag) { place_ = place; gpuStream_t stream = nullptr; backends::gpu::GPUDeviceGuard guard(place_.device); - if (priority == Priority::kHigh) { -#ifdef PADDLE_WITH_HIP - PADDLE_ENFORCE_GPU_SUCCESS(hipStreamCreateWithPriority( - &stream, static_cast(flag), -1)); -#else - PADDLE_ENFORCE_GPU_SUCCESS(cudaStreamCreateWithPriority( - &stream, static_cast(flag), -1)); -#endif - } else if (priority == Priority::kNormal) { + + // Stream priorities follow a convention where lower numbers imply greater + // priorities + auto priority_range = backends::gpu::GetGpuStreamPriorityRange(); + int least_priority = priority_range.first; // 0 in V100 + int greatest_priority = priority_range.second; // -5 in V100 + + // NOTE(Ruibiao): Replacing the following `PADDLE_ENFORCE_EQ` with + // `PADDLE_ENFORCE` leads to a nvcc compile error. This is probably a bug. + PADDLE_ENFORCE_EQ( + priority <= least_priority && priority >= greatest_priority, + true, + phi::errors::InvalidArgument( + "Cannot create a stream with priority = %d because stream priority " + "must be inside the meaningful range [%d, %d].", + priority, + least_priority, + greatest_priority)); + #ifdef PADDLE_WITH_HIP - PADDLE_ENFORCE_GPU_SUCCESS(hipStreamCreateWithPriority( - &stream, static_cast(flag), 0)); + PADDLE_ENFORCE_GPU_SUCCESS(hipStreamCreateWithPriority( + &stream, static_cast(flag), priority)); #else - PADDLE_ENFORCE_GPU_SUCCESS(cudaStreamCreateWithPriority( - &stream, static_cast(flag), 0)); + PADDLE_ENFORCE_GPU_SUCCESS(cudaStreamCreateWithPriority( + &stream, static_cast(flag), priority)); #endif - } - VLOG(10) << "CUDAStream " << stream; + + VLOG(10) << "Create CUDAStream " << stream + << " with priority = " << priority + << ", flag = " << static_cast(flag); stream_ = Stream(reinterpret_cast(stream)); owned_ = true; } diff --git a/python/paddle/fluid/tests/unittests/standalone_executor/test_standalone_custom_stream.py b/python/paddle/fluid/tests/unittests/standalone_executor/test_standalone_custom_stream.py index e3d4bdbe0f3..116aa60d052 100644 --- a/python/paddle/fluid/tests/unittests/standalone_executor/test_standalone_custom_stream.py +++ b/python/paddle/fluid/tests/unittests/standalone_executor/test_standalone_custom_stream.py @@ -50,8 +50,10 @@ class TestCustomStream(unittest.TestCase): ops = prog.global_block().ops for op_index in op_index_for_stream1: ops[op_index].dist_attr.execution_stream = "s1" + ops[op_index].dist_attr.stream_priority = -1 for op_index in op_index_for_stream2: ops[op_index].dist_attr.execution_stream = "s2" + ops[op_index].dist_attr.stream_priority = -2 def run_program(self, apply_custom_stream=False): paddle.seed(2022) -- GitLab