From f9591bb172e7274a77bfdcb6493579824aec8b47 Mon Sep 17 00:00:00 2001
From: Zeng Jinle <32832641+sneaxiy@users.noreply.github.com>
Date: Fri, 8 Oct 2021 18:06:26 +0800
Subject: [PATCH] Support CUDA Graph on ParallelExecutor (#36250)

* support CUDA Graph on PE

* add ut, fix CI compile

* reduce memory consumption

* fix CUDA 10 CI

* improve coverage

* improve python coverage
---
 .../fluid/framework/details/build_strategy.h  |   2 +
 .../details/scale_loss_grad_op_handle.cc      |  19 ++-
 .../details/scale_loss_grad_op_handle.h       |   6 +
 .../scope_buffered_ssa_graph_executor.cc      |  53 ++++---
 .../scope_buffered_ssa_graph_executor.h       |   2 +-
 .../framework/distributed_strategy.proto      |   1 +
 .../multi_devices_graph_pass/CMakeLists.txt   |   2 +-
 .../modify_op_lock_and_record_event_pass.cc   |  14 +-
 paddle/fluid/framework/parallel_executor.cc   | 143 ++++++++++++++++++
 paddle/fluid/framework/parallel_executor.h    |   2 +
 paddle/fluid/operators/conv_cudnn_helper.h    |   3 +
 paddle/fluid/platform/cuda_graph.cc           |  12 ++
 paddle/fluid/platform/cuda_graph.h            |  10 +-
 .../platform/cuda_graph_with_memory_pool.cc   |   9 +-
 paddle/fluid/platform/gpu_info.cc             |   2 +-
 paddle/fluid/pybind/pybind.cc                 |  27 +++-
 python/paddle/fluid/executor.py               |  12 +-
 .../fluid/tests/unittests/test_cuda_graph.py  |  91 ++++++++++-
 18 files changed, 368 insertions(+), 42 deletions(-)

diff --git a/paddle/fluid/framework/details/build_strategy.h b/paddle/fluid/framework/details/build_strategy.h
index 0629f1b915..25110fe24f 100644
--- a/paddle/fluid/framework/details/build_strategy.h
+++ b/paddle/fluid/framework/details/build_strategy.h
@@ -143,6 +143,8 @@ struct BuildStrategy {
   // Turn off inplace addto by default.
   bool enable_addto_{false};
 
+  bool allow_cuda_graph_capture_{false};
+
   // FIXME(zcd): is_distribution_ is a temporary field, because in pserver mode,
   // num_trainers is 1, so the current fields of build_strategy doesn't tell if
   // it's distributed model.
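The new allow_cuda_graph_capture_ flag surfaces in Python as BuildStrategy.allow_cuda_graph_capture (bound in the pybind.cc hunk later in this patch). As an illustration only, not part of the patch, a user opts in roughly the way the new unit test does:

    import paddle

    paddle.enable_static()
    build_strategy = paddle.static.BuildStrategy()
    # Opt in to CUDA Graph capture; fixing the op run order keeps the
    # op sequence identical across runs, which capture relies on.
    build_strategy.allow_cuda_graph_capture = True
    build_strategy.fix_op_run_order = True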
diff --git a/paddle/fluid/framework/details/scale_loss_grad_op_handle.cc b/paddle/fluid/framework/details/scale_loss_grad_op_handle.cc
index c0c3e14c8b..1e3cd4f0aa 100644
--- a/paddle/fluid/framework/details/scale_loss_grad_op_handle.cc
+++ b/paddle/fluid/framework/details/scale_loss_grad_op_handle.cc
@@ -86,19 +86,28 @@ struct ScaleLossGradFunctor {
   }
 };
 
+std::string ScaleLossGradOpHandle::LossGradName() const {
+  return static_cast<VarHandle *>(this->outputs_[0])->name();
+}
+
 void ScaleLossGradOpHandle::RunImpl() {
   platform::RecordEvent record_event(Name());
-  // Doesn't wait any event
-  std::string var_name = static_cast<VarHandle *>(this->outputs_[0])->name();
+  RunOnVar(local_exec_scopes_[0]->FindVar(LossGradName()), true);
+}
 
-  auto *tensor =
-      local_exec_scopes_[0]->FindVar(var_name)->GetMutable<LoDTensor>();
+void ScaleLossGradOpHandle::RunOnVar(Variable *var, bool record_event) {
+  auto *tensor = var->GetMutable<LoDTensor>();
   tensor->Resize(make_ddim({1}));
 
 #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
   ScaleLossGradFunctor func(coeff_, tensor, place_, out_dtype_,
                             this->dev_ctxes_.at(place_));
-  this->RunAndRecordEvent([&] { framework::VisitDataType(out_dtype_, func); });
+  if (record_event) {
+    this->RunAndRecordEvent(
+        [&] { framework::VisitDataType(out_dtype_, func); });
+  } else {
+    framework::VisitDataType(out_dtype_, func);
+  }
 #else
   ScaleLossGradFunctor func(coeff_, tensor, place_, out_dtype_, nullptr);
   framework::VisitDataType(out_dtype_, func);
diff --git a/paddle/fluid/framework/details/scale_loss_grad_op_handle.h b/paddle/fluid/framework/details/scale_loss_grad_op_handle.h
index 02e5aa8844..88fe02a749 100644
--- a/paddle/fluid/framework/details/scale_loss_grad_op_handle.h
+++ b/paddle/fluid/framework/details/scale_loss_grad_op_handle.h
@@ -46,6 +46,12 @@ struct ScaleLossGradOpHandle : public OpHandleBase {
 
   std::string Name() const override;
 
+  platform::Place GetPlace() const { return place_; }
+
+  void RunOnVar(Variable *var, bool record_event = false);
+
+  std::string LossGradName() const;
+
  protected:
   void RunImpl() override;
diff --git a/paddle/fluid/framework/details/scope_buffered_ssa_graph_executor.cc b/paddle/fluid/framework/details/scope_buffered_ssa_graph_executor.cc
index ad47846c59..5d271d06b6 100644
--- a/paddle/fluid/framework/details/scope_buffered_ssa_graph_executor.cc
+++ b/paddle/fluid/framework/details/scope_buffered_ssa_graph_executor.cc
@@ -22,7 +22,9 @@
 #include "paddle/fluid/framework/details/multi_devices_helper.h"
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/framework/variable_helper.h"
+#include "paddle/fluid/platform/cuda_graph_with_memory_pool.h"
 #include "paddle/fluid/platform/profiler.h"
+
 namespace paddle {
 namespace framework {
 namespace details {
@@ -49,8 +51,29 @@ ScopeBufferedSSAGraphExecutor::ScopeBufferedSSAGraphExecutor(
   PrepareLocalExeScopes();
 }
 
+static void RunProgramDescs(const ProgramDescs &programs,
+                            const std::vector<Scope *> &local_exec_scopes,
+                            const std::vector<platform::Place> &places) {
+  for (auto &program : programs) {
+    for (auto &op_desc : program.Block(0).AllOps()) {
+      for (size_t i = 0; i < local_exec_scopes.size(); ++i) {
+        auto op = OpRegistry::CreateOp(*op_desc);
+        op->Run(*local_exec_scopes[i], places[i]);
+      }
+    }
+  }
+}
+
 FetchResultType ScopeBufferedSSAGraphExecutor::Run(
     const std::vector<std::string> &fetch_tensors, bool return_merged) {
+#ifdef PADDLE_WITH_CUDA
+  if (platform::IsCUDAGraphCapturing()) {
+    strategy_.num_iteration_per_drop_scope_ =
+        std::numeric_limits<size_t>::max();
+    DropLocalExeScopes(/*need_wait=*/false);
+  }
+#endif
+
   if (drop_scope_counter_ == 0) {
     platform::RecordEvent e("InitLocalVars");
     InitVariables();
@@ -84,7 +107,7 @@ FetchResultType ScopeBufferedSSAGraphExecutor::Run(
   ++drop_scope_counter_;
   if (drop_scope_counter_ == strategy_.num_iteration_per_drop_scope_ ||
       DropScopeOrNot()) {
-    DropLocalExeScopes();
+    DropLocalExeScopes(!platform::IsCUDAGraphCapturing());
   }
 
   if (VLOG_IS_ON(5)) {
@@ -128,15 +151,7 @@ void ScopeBufferedSSAGraphExecutor::InitVariables() {
     if (graph.Has(details::kStartupProgramDescs)) {
       auto &program_descs =
           graph.Get<details::ProgramDescs>(details::kStartupProgramDescs);
-
-      for (auto &program_desc : program_descs) {
-        for (auto &op_desc : program_desc.Block(0).AllOps()) {
-          for (size_t i = 0; i < local_exec_scopes_.size(); ++i) {
-            auto op = OpRegistry::CreateOp(*op_desc);
-            op->Run(*local_exec_scopes_[i], places_[i]);
-          }
-        }
-      }
+      RunProgramDescs(program_descs, local_exec_scopes_, places_);
     }
     is_initialized_ = true;
   }
@@ -144,23 +159,17 @@ void ScopeBufferedSSAGraphExecutor::InitVariables() {
   if (graph.Has(details::kProgramDescs)) {
     auto &program_descs =
         graph.Get<details::ProgramDescs>(details::kProgramDescs);
-
-    for (auto &program_desc : program_descs) {
-      for (auto &op_desc : program_desc.Block(0).AllOps()) {
-        for (size_t i = 0; i < local_exec_scopes_.size(); ++i) {
-          auto op = OpRegistry::CreateOp(*op_desc);
-          op->Run(*local_exec_scopes_[i], places_[i]);
-        }
-      }
-    }
+    RunProgramDescs(program_descs, local_exec_scopes_, places_);
   }
 }
 
-void ScopeBufferedSSAGraphExecutor::DropLocalExeScopes() {
+void ScopeBufferedSSAGraphExecutor::DropLocalExeScopes(bool need_wait) {
   platform::RecordEvent drop_scope_event("DropLocalExeScopes");
   drop_scope_counter_ = 0;
-  for (auto &p : places_) {
-    platform::DeviceContextPool::Instance().Get(p)->Wait();
+  if (need_wait) {
+    for (auto &p : places_) {
+      platform::DeviceContextPool::Instance().Get(p)->Wait();
+    }
   }
   scope_monitor_.ClearHistoryLocalExecScopes();
   for (size_t i = 0; i < local_exec_scopes_.size(); ++i) {
diff --git a/paddle/fluid/framework/details/scope_buffered_ssa_graph_executor.h b/paddle/fluid/framework/details/scope_buffered_ssa_graph_executor.h
index aa2b113c96..ea5a3c0795 100644
--- a/paddle/fluid/framework/details/scope_buffered_ssa_graph_executor.h
+++ b/paddle/fluid/framework/details/scope_buffered_ssa_graph_executor.h
@@ -53,7 +53,7 @@ class ScopeBufferedSSAGraphExecutor : public SSAGraphExecutor {
   FetchResultType Run(const std::vector<std::string>& fetch_tensors,
                       bool return_merged) override;
 
-  void DropLocalExeScopes();
+  void DropLocalExeScopes(bool need_wait = true);
 
   bool NeedCreateLocalExeScope();
diff --git a/paddle/fluid/framework/distributed_strategy.proto b/paddle/fluid/framework/distributed_strategy.proto
index 17d15a94c7..e7a25de96a 100644
--- a/paddle/fluid/framework/distributed_strategy.proto
+++ b/paddle/fluid/framework/distributed_strategy.proto
@@ -115,6 +115,7 @@ message BuildStrategy {
   optional bool enable_auto_fusion = 11 [ default = false ];
   optional bool enable_addto = 12 [ default = false ];
   optional bool fix_op_run_order = 13 [ default = false ];
+  optional bool allow_cuda_graph_capture = 14 [ default = false ];
 }
 
 message ExecutionStrategy {
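The proto field mirrors the C++ flag so that fleet's DistributedStrategy can carry it. The following sketch is an assumption, not exercised by this patch: it relies on the DistributedStrategy.build_strategy setter copying BuildStrategy attributes into the proto message field by field.

    import paddle
    import paddle.distributed.fleet as fleet

    strategy = fleet.DistributedStrategy()
    bs = paddle.static.BuildStrategy()
    bs.allow_cuda_graph_capture = True
    # The setter copies matching fields by name, so the new
    # `allow_cuda_graph_capture = 14` proto field would pick this up.
    strategy.build_strategy = bs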
diff --git a/paddle/fluid/framework/ir/multi_devices_graph_pass/CMakeLists.txt b/paddle/fluid/framework/ir/multi_devices_graph_pass/CMakeLists.txt
index 6764799d82..fea12baf06 100644
--- a/paddle/fluid/framework/ir/multi_devices_graph_pass/CMakeLists.txt
+++ b/paddle/fluid/framework/ir/multi_devices_graph_pass/CMakeLists.txt
@@ -1,4 +1,4 @@
-cc_library(modify_op_lock_and_record_event_pass SRCS modify_op_lock_and_record_event_pass.cc DEPS computation_op_handle op_graph_view multi_devices_helper)
+cc_library(modify_op_lock_and_record_event_pass SRCS modify_op_lock_and_record_event_pass.cc DEPS computation_op_handle scale_loss_grad_op_handle op_graph_view multi_devices_helper)
 cc_library(multi_devices_graph_print_pass SRCS multi_devices_graph_print_pass.cc DEPS multi_devices_helper)
 cc_library(multi_devices_graph_check_pass SRCS multi_devices_graph_check_pass.cc DEPS multi_devices_helper)
diff --git a/paddle/fluid/framework/ir/multi_devices_graph_pass/modify_op_lock_and_record_event_pass.cc b/paddle/fluid/framework/ir/multi_devices_graph_pass/modify_op_lock_and_record_event_pass.cc
index 70b95c9154..afd80e45cf 100644
--- a/paddle/fluid/framework/ir/multi_devices_graph_pass/modify_op_lock_and_record_event_pass.cc
+++ b/paddle/fluid/framework/ir/multi_devices_graph_pass/modify_op_lock_and_record_event_pass.cc
@@ -14,6 +14,7 @@
 
 #include "paddle/fluid/framework/details/computation_op_handle.h"
 #include "paddle/fluid/framework/details/multi_devices_helper.h"
+#include "paddle/fluid/framework/details/scale_loss_grad_op_handle.h"
 #include "paddle/fluid/framework/ir/graph_helper.h"
 #include "paddle/fluid/framework/ir/memory_optimize_pass/op_graph_view.h"
 
@@ -21,14 +22,23 @@ namespace paddle {
 namespace framework {
 namespace ir {
 
+template <typename TOp>
+static bool IsMatchedPlaceSingleDeviceOp(details::OpHandleBase *op_base,
+                                         const platform::Place &place) {
+  auto *op = dynamic_cast<TOp *>(op_base);
+  return op && op->GetPlace() == place;
+}
+
 static bool IsLockAndRecordEventFreeComputationOpHandle(
     details::ComputationOpHandle *op, const OpGraphView &graph_view) {
   if (!platform::is_gpu_place(op->GetPlace()) &&
       !platform::is_xpu_place(op->GetPlace()))
     return false;
   for (auto &pending_op : graph_view.PendingOps(op)) {
-    auto *tmp = dynamic_cast<details::ComputationOpHandle *>(pending_op);
-    if (tmp == nullptr || !(tmp->GetPlace() == op->GetPlace())) {
+    if (!IsMatchedPlaceSingleDeviceOp<details::ComputationOpHandle>(
+            pending_op, op->GetPlace()) &&
+        !IsMatchedPlaceSingleDeviceOp<details::ScaleLossGradOpHandle>(
+            pending_op, op->GetPlace())) {
       return false;
     }
   }
diff --git a/paddle/fluid/framework/parallel_executor.cc b/paddle/fluid/framework/parallel_executor.cc
index adbbfb380b..d19ac0b65f 100644
--- a/paddle/fluid/framework/parallel_executor.cc
+++ b/paddle/fluid/framework/parallel_executor.cc
@@ -27,6 +27,7 @@ limitations under the License. */
 #include "paddle/fluid/framework/details/multi_devices_helper.h"
 #include "paddle/fluid/framework/details/op_handle_base.h"
 #include "paddle/fluid/framework/details/parallel_ssa_graph_executor.h"
+#include "paddle/fluid/framework/details/scale_loss_grad_op_handle.h"
 #include "paddle/fluid/framework/details/threaded_ssa_graph_executor.h"
 #include "paddle/fluid/framework/ir/graph.h"
 #include "paddle/fluid/framework/ir/graph_helper.h"
@@ -34,6 +35,7 @@ limitations under the License. */
 #include "paddle/fluid/framework/ir/memory_optimize_pass/reference_count_pass_helper.h"
 #include "paddle/fluid/framework/ir/multi_devices_graph_pass/set_reader_device_info_utils.h"
 #include "paddle/fluid/framework/variable_helper.h"
+#include "paddle/fluid/platform/cuda_graph_with_memory_pool.h"
 #include "paddle/fluid/platform/event.h"
 #include "paddle/fluid/platform/profiler.h"
@@ -43,6 +45,10 @@ limitations under the License. */
 DECLARE_double(eager_delete_tensor_gb);
 
+#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
+DECLARE_bool(sync_nccl_allreduce);
+#endif
+
 #ifdef WITH_GPERFTOOLS
 #include "gperftools/profiler.h"
 #endif
@@ -669,6 +675,7 @@ ParallelExecutor::ParallelExecutor(const std::vector<platform::Place> &places,
   // ncclOp
   std::vector<ir::Graph *> async_graphs =
       CompileGraphWithBuildStrategy(graph, &graphs, loss_var_name);
+  PrepareForCUDAGraphCapture(graph);
   graph = member_->ApplyMemoryOptimizePass(graph);
   async_graphs[0] = graph;
 
@@ -882,6 +889,23 @@ void ParallelExecutor::BCastParamsToDevices(
 FetchResultType ParallelExecutor::Run(
     const std::vector<std::string> &fetch_tensors, bool return_merged) {
   VLOG(3) << "enter ParallelExecutor Run";
+#ifdef PADDLE_WITH_CUDA
+  if (platform::IsCUDAGraphCapturing()) {
+    PADDLE_ENFORCE_EQ(fetch_tensors.empty(), true,
+                      platform::errors::InvalidArgument(
+                          "Cannot fetch data when using CUDA Graph."));
+    PADDLE_ENFORCE_EQ(
+        member_->build_strategy_.allow_cuda_graph_capture_, true,
+        platform::errors::InvalidArgument(
+            "You must turn on build_strategy.allow_cuda_graph_capture = True "
+            "to enable CUDA Graph capturing."));
+    PADDLE_ENFORCE_EQ(
+        member_->places_[0], platform::CUDAGraphCapturingPlace(),
+        platform::errors::InvalidArgument("The place to capture CUDAGraph is "
+                                          "not the same as the place to run."));
+  }
+#endif
+
 #ifdef WITH_GPERFTOOLS
   if (gProfileStarted) {
     ProfilerFlush();
@@ -932,6 +956,16 @@ void ParallelExecutor::SkipMemoryReuse(
 
 void ParallelExecutor::FeedTensorsIntoLocalScopes(
     const std::vector<std::unordered_map<std::string, LoDTensor>> &tensors) {
+  if (platform::IsCUDAGraphCapturing()) {
+    for (auto &tensor : tensors) {
+      PADDLE_ENFORCE_EQ(
+          tensor.empty(), true,
+          platform::errors::PermissionDenied(
+              "Feeding data is not permitted when capturing CUDA Graph."));
+    }
+    return;
+  }
+
   if (!member_->AllowPartialFeed()) {
     PADDLE_ENFORCE_EQ(tensors.size(), member_->local_scopes_.size(),
                       platform::errors::Unimplemented(
@@ -987,6 +1021,14 @@ void ParallelExecutor::FeedTensorsIntoLocalScopes(
 
 void ParallelExecutor::FeedAndSplitTensorIntoLocalScopes(
     const std::unordered_map<std::string, LoDTensor> &tensors) {
+  if (platform::IsCUDAGraphCapturing()) {
+    PADDLE_ENFORCE_EQ(
+        tensors.empty(), true,
+        platform::errors::PermissionDenied(
+            "Feeding data is not permitted when capturing CUDA Graph."));
+    return;
+  }
+
   size_t num_places = member_->places_.size();
   bool allow_partial_feed = member_->AllowPartialFeed();
 
@@ -1568,6 +1610,107 @@ const ir::Graph &ParallelExecutor::Graph() const {
   return member_->executor_->Graph();
 }
 
+void ParallelExecutor::PrepareForCUDAGraphCapture(ir::Graph *graph) {
+  const auto &build_strategy = member_->build_strategy_;
+  if (!build_strategy.allow_cuda_graph_capture_) return;
+#ifdef PADDLE_WITH_CUDA
+  PADDLE_ENFORCE_EQ(
+      build_strategy.async_mode_, false,
+      platform::errors::InvalidArgument(
+          "Async Executor does not support CUDA Graph capturing."));
+  PADDLE_ENFORCE_EQ(
+      platform::IsCUDAGraphCapturing(), false,
+      platform::errors::PermissionDenied("CUDA Graph is not allowed to capture "
+                                         "when running the first batch."));
+  PADDLE_ENFORCE_EQ(
+      member_->places_.size(), 1,
+      platform::errors::InvalidArgument(
+          "CUDA Graph is only supported when one GPU device is running."));
+  PADDLE_ENFORCE_EQ(platform::is_gpu_place(member_->places_[0]), true,
+                    platform::errors::InvalidArgument(
+                        "CUDA Graph is only supported on NVIDIA GPU device."));
+  PADDLE_ENFORCE_EQ(FLAGS_sync_nccl_allreduce, false,
+                    platform::errors::InvalidArgument(
+                        "FLAGS_sync_nccl_allreduce must be False to support "
+                        "CUDA Graph capturing."));
+
+  std::unordered_map<std::string, std::vector<VarDesc *>> all_vars;
+  for (auto &node : graph->Nodes()) {
+    if (node->IsVar() && !node->IsCtrlVar() && node->Var()) {
+      auto *var_desc = node->Var();
+      all_vars[var_desc->Name()].emplace_back(var_desc);
+    }
+  }
+
+  auto mark_var_as_persistable = [&all_vars](const std::string &name) {
+    auto iter = all_vars.find(name);
+    if (iter != all_vars.end()) {
+      for (auto *var_desc : iter->second) {
+        var_desc->SetPersistable(true);
+      }
+    }
+  };
+
+  // Step 1: All fused vars must be persistable.
+  if (graph->Has(details::kFusedVars)) {
+    auto &fused_vars = graph->Get<details::FusedVars>(details::kFusedVars);
+    for (auto &fused_var : fused_vars) {
+      fused_var.second.persistable_ = true;
+      mark_var_as_persistable(fused_var.first);
+    }
+  }
+
+  // Step 2: All pinned vars must be persistable.
+  if (graph->Has(details::kPinnedVars)) {
+    auto &pinned_vars = graph->Get<details::PinnedVars>(details::kPinnedVars);
+    for (auto &pinned_var : pinned_vars) {
+      mark_var_as_persistable(pinned_var);
+    }
+  }
+
+  // Step 3: Move all main programs to startup programs to make sure that
+  // the main programs would only be run once.
+  if (graph->Has(details::kProgramDescs)) {
+    auto &startup_programs =
+        graph->GetOrInit<details::ProgramDescs>(details::kStartupProgramDescs);
+    auto &main_programs =
+        graph->Get<details::ProgramDescs>(details::kProgramDescs);
+    for (auto &main_program : main_programs) {
+      startup_programs.emplace_back(main_program);
+    }
+    graph->Erase(details::kProgramDescs);
+  }
+
+  // Step 4: Mark all vars in startup programs to be persistable.
+  if (graph->Has(details::kStartupProgramDescs)) {
+    auto &startup_programs =
+        graph->GetOrInit<details::ProgramDescs>(details::kStartupProgramDescs);
+    for (auto &startup_program : startup_programs) {
+      for (auto &op_desc : startup_program.Block(0).AllOps()) {
+        for (auto &output : op_desc->OutputArgumentNames()) {
+          mark_var_as_persistable(output);
+        }
+      }
+    }
+  }
+
+  // Step 5: ScaleLossGrad must be run beforehand to avoid H2D copy.
+  auto ops = ir::FilterByNodeWrapper<details::OpHandleBase>(*graph);
+  auto *scope = member_->local_scopes_[0];
+  for (auto *op : ops) {
+    auto *loss_grad_op = dynamic_cast<details::ScaleLossGradOpHandle *>(op);
+    if (loss_grad_op == nullptr) continue;
+    auto loss_grad_name = loss_grad_op->LossGradName();
+    mark_var_as_persistable(loss_grad_name);
+    loss_grad_op->RunOnVar(scope->Var(loss_grad_name));
+    loss_grad_op->SetSkipRunning(true);
+  }
+#else
+  PADDLE_THROW(platform::errors::Unimplemented(
+      "CUDA Graph is only supported on NVIDIA GPU device."));
+#endif
+}
+
 }  // namespace framework
 }  // namespace paddle
diff --git a/paddle/fluid/framework/parallel_executor.h b/paddle/fluid/framework/parallel_executor.h
index 6c871a8d85..78774f0489 100644
--- a/paddle/fluid/framework/parallel_executor.h
+++ b/paddle/fluid/framework/parallel_executor.h
@@ -144,6 +144,8 @@ class ParallelExecutor {
   void SetReaderOpDeviceInfoOfGraphs(
       const std::vector<ir::Graph *> &final_graphs);
 
+  void PrepareForCUDAGraphCapture(ir::Graph *graph);
+
   ParallelExecutorPrivate *member_;
   std::vector<std::unique_ptr<ir::Graph>> async_graphs_;
   std::vector<details::VariableInfo> var_infos_;
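PrepareForCUDAGraphCapture pins down the usage contract: a single CUDA place, build_strategy.allow_cuda_graph_capture turned on, no feed or fetch while capturing, and capture only after a warm-up batch. A condensed sketch of the intended run loop, for illustration only: exe, place, compiled_program and the input tensors image_t/label_t are assumed to be set up as in the unit test at the end of this patch, and num_batches/next_image/next_label are hypothetical.

    from paddle.device.cuda.graphs import CUDAGraph

    exe.run(compiled_program)   # warm-up batch; capturing the first batch is rejected

    cuda_graph = CUDAGraph(place, mode="global")
    cuda_graph.capture_begin()
    exe.run(compiled_program)   # captured batch: no feed dict, no fetch_list
    cuda_graph.capture_end()

    for _ in range(num_batches):
        image_t.set(next_image(), place)   # refresh persistable inputs in place
        label_t.set(next_label(), place)
        cuda_graph.replay()
    cuda_graph.reset()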
diff --git a/paddle/fluid/operators/conv_cudnn_helper.h b/paddle/fluid/operators/conv_cudnn_helper.h
index 4c0ef02074..f4183bf570 100644
--- a/paddle/fluid/operators/conv_cudnn_helper.h
+++ b/paddle/fluid/operators/conv_cudnn_helper.h
@@ -24,6 +24,7 @@ limitations under the License. */
 #include "paddle/fluid/framework/operator_kernel_configs.h"
 #include "paddle/fluid/operators/conv_cudnn_op_cache.h"
 #include "paddle/fluid/operators/eigen/eigen_function.h"
+#include "paddle/fluid/platform/cuda_graph_with_memory_pool.h"
 #include "paddle/fluid/platform/cudnn_desc.h"
 
 namespace paddle {
 namespace operators {
@@ -480,6 +481,7 @@ struct SearchAlgorithm {
   static algo_t Find(const ConvArgs& args, bool exhaustive_search,
                      bool deterministic,
                      const framework::ExecutionContext& ctx) {
+    platform::CUDAGraphCaptureModeGuard guard;
     auto dtype = platform::CudnnDataType<T>::type;
     size_t workspace_size_limit = FLAGS_conv_workspace_size_limit * 1024 * 1024;
     size_t workspace_size = 0;
@@ -601,6 +603,7 @@ struct SearchAlgorithm {
   }
 
   static size_t GetWorkspaceSize(const ConvArgs& args, algo_t algo) {
+    platform::CUDAGraphCaptureModeGuard guard;
     size_t workspace_size = 0;
     PADDLE_ENFORCE_CUDA_SUCCESS(
         platform::dynload::cudnnGetConvolutionBackwardFilterWorkspaceSize(
diff --git a/paddle/fluid/platform/cuda_graph.cc b/paddle/fluid/platform/cuda_graph.cc
index 6e518d779e..693a592799 100644
--- a/paddle/fluid/platform/cuda_graph.cc
+++ b/paddle/fluid/platform/cuda_graph.cc
@@ -70,6 +70,9 @@ void CUDAGraph::BeginCapture(platform::CUDAPlace place, cudaStream_t stream,
   cudaStreamCaptureStatus status;
   PADDLE_ENFORCE_CUDA_SUCCESS(cudaStreamGetCaptureInfo(
       capturing_graph_->stream_, &status, &(capturing_graph_->id_)));
+  PADDLE_ENFORCE_EQ(IsValidCapturing(), true,
+                    platform::errors::PermissionDenied(
+                        "CUDA Graph should not be invalidated."));
   VLOG(10) << "Begin to capture CUDA Graph with ID " << capturing_graph_->id_;
 }
 
@@ -88,5 +91,14 @@ std::unique_ptr<CUDAGraph> CUDAGraph::EndCapture() {
 #endif
 }
 
+bool CUDAGraph::IsValidCapturing() {
+  if (!IsCapturing()) return false;
+  cudaStreamCaptureStatus status;
+  CUDAGraphID id;
+  PADDLE_ENFORCE_CUDA_SUCCESS(
+      cudaStreamGetCaptureInfo(capturing_graph_->stream_, &status, &id));
+  return status == cudaStreamCaptureStatusActive;
+}
+
 }  // namespace platform
 }  // namespace paddle
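BeginCapture/EndCapture also back the dynamic-graph wrapper paddle.device.cuda.graphs.CUDAGraph, which the renamed test_cuda_graph_dynamic_graph test below exercises. A minimal dygraph sketch; the shapes and the arithmetic are illustrative only:

    import paddle
    from paddle.device.cuda.graphs import CUDAGraph

    x = paddle.rand([2, 3])              # inputs must be allocated before capture
    g = CUDAGraph()                      # defaults to the current CUDA place
    g.capture_begin()                    # BeginCUDAGraphCapture() underneath
    y = x * 2 + 1                        # recorded into the graph, not executed now
    g.capture_end()                      # EndCUDAGraphCapture()

    x.copy_(paddle.rand([2, 3]), False)  # refresh the input buffer in place
    g.replay()                           # recomputes y from the captured graph
    g.reset()                            # releases the captured graph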
diff --git a/paddle/fluid/platform/cuda_graph.h b/paddle/fluid/platform/cuda_graph.h
index 41e36049aa..55ec463556 100644
--- a/paddle/fluid/platform/cuda_graph.h
+++ b/paddle/fluid/platform/cuda_graph.h
@@ -84,6 +84,10 @@ class CUDAGraph {
     return capturing_graph_->place_;
   }
 
+  // This API can be used to debug which GPU operation is not
+  // supported during capturing CUDA Graph.
+  static bool IsValidCapturing();
+
  private:
 #if CUDA_VERSION >= 10010
   cudaGraph_t graph_{nullptr};
@@ -104,7 +108,8 @@ class CUDAGraphCaptureModeGuard {
   DISABLE_COPY_AND_ASSIGN(CUDAGraphCaptureModeGuard);
 
  public:
-  explicit CUDAGraphCaptureModeGuard(cudaStreamCaptureMode mode) {
+  explicit CUDAGraphCaptureModeGuard(
+      cudaStreamCaptureMode mode = cudaStreamCaptureModeRelaxed) {
     if (UNLIKELY(CUDAGraph::IsCapturing())) {
       PADDLE_ENFORCE_CUDA_SUCCESS(cudaThreadExchangeStreamCaptureMode(&mode));
       // After cudaThreadExchangeStreamCaptureMode is called,
@@ -128,7 +133,8 @@ class CUDAGraphCaptureModeGuard {
   DISABLE_COPY_AND_ASSIGN(CUDAGraphCaptureModeGuard);
 
  public:
-  explicit CUDAGraphCaptureModeGuard(cudaStreamCaptureMode) {}
+  explicit CUDAGraphCaptureModeGuard(
+      cudaStreamCaptureMode mode = cudaStreamCaptureModeRelaxed) {}
 };
 #endif
diff --git a/paddle/fluid/platform/cuda_graph_with_memory_pool.cc b/paddle/fluid/platform/cuda_graph_with_memory_pool.cc
index 1f0d39e2ab..4804d3f6ed 100644
--- a/paddle/fluid/platform/cuda_graph_with_memory_pool.cc
+++ b/paddle/fluid/platform/cuda_graph_with_memory_pool.cc
@@ -22,8 +22,10 @@ namespace platform {
 #ifdef PADDLE_WITH_CUDA
 void BeginCUDAGraphCapture(platform::CUDAPlace place,
                            cudaStreamCaptureMode mode) {
-  auto stream =
-      platform::DeviceContextPool::Instance().GetByPlace(place)->stream();
+  auto *dev_ctx = platform::DeviceContextPool::Instance().GetByPlace(place);
+  dev_ctx->cudnn_workspace_handle().ResetWorkspace();
+
+  auto stream = dev_ctx->stream();
   CUDAGraph::BeginCapture(place, stream, mode);
   auto id = CUDAGraph::CapturingID();
   memory::allocation::AllocatorFacade::Instance().PrepareMemoryPoolForCUDAGraph(
@@ -35,6 +37,9 @@ void BeginCUDAGraphCapture(platform::CUDAPlace place,
 }
 
 std::unique_ptr<CUDAGraph> EndCUDAGraphCapture() {
+  auto place = CUDAGraph::CapturingPlace();
+  auto *dev_ctx = platform::DeviceContextPool::Instance().GetByPlace(place);
+  dev_ctx->cudnn_workspace_handle().ResetWorkspace();
   return CUDAGraph::EndCapture();
 }
 #endif
diff --git a/paddle/fluid/platform/gpu_info.cc b/paddle/fluid/platform/gpu_info.cc
index 59e4404ffe..c624ba94b7 100644
--- a/paddle/fluid/platform/gpu_info.cc
+++ b/paddle/fluid/platform/gpu_info.cc
@@ -558,7 +558,7 @@ class RecordedCudaMallocHelper {
 #ifdef PADDLE_WITH_HIP
     auto result = hipMalloc(ptr, size);
 #else
-    CUDAGraphCaptureModeGuard capture_mode_guard{cudaStreamCaptureModeRelaxed};
+    CUDAGraphCaptureModeGuard capture_mode_guard;
     auto result = cudaMalloc(ptr, size);
 #endif
     if (result == gpuSuccess) {
diff --git a/paddle/fluid/pybind/pybind.cc b/paddle/fluid/pybind/pybind.cc
index 6b24c64492..f58c2a5db3 100644
--- a/paddle/fluid/pybind/pybind.cc
+++ b/paddle/fluid/pybind/pybind.cc
@@ -736,6 +736,17 @@ PYBIND11_MODULE(core_noavx, m) {
              paddle::framework::proto::VarType::Type type) {
             return reinterpret_cast<uintptr_t>(self.mutable_data(place, type));
           })
+      .def("_copy_from",
+           [](framework::Tensor &self, const framework::Tensor &other,
+              const platform::Place &place, int64_t batch_size) {
+             if (batch_size < 0) {
+               framework::TensorCopy(other, place, &self);
+             } else {
+               auto sliced = other.Slice(0, batch_size);
+               framework::TensorCopy(sliced, place, &self);
+             }
+           },
+           py::arg("tensor"), py::arg("place"), py::arg("batch_size") = -1)
       .def("set", SetTensorFromPyArray<paddle::platform::CPUPlace>,
            py::arg("array"), py::arg("place"), py::arg("zero_copy") = false)
       .def("set", SetTensorFromPyArray<paddle::platform::XPUPlace>,
@@ -2299,7 +2310,14 @@ All parameter, weight, gradient are variables in Paddle.
   m.def("op_support_gpu", OpSupportGPU);
 #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
   m.def("get_cuda_device_count", platform::GetCUDADeviceCount);
-  m.def("cuda_empty_cache", platform::EmptyCache);
+  m.def("cuda_empty_cache", [] {
+    for (int dev_id : platform::GetSelectedDevices()) {
+      auto *dev_ctx = platform::DeviceContextPool::Instance().GetByPlace(
+          platform::CUDAPlace(dev_id));
+      dev_ctx->cudnn_workspace_handle().ResetWorkspace();
+    }
+    platform::EmptyCache();
+  });
   m.def("get_device_properties",
         [](int id) -> const gpuDeviceProp & {
           return platform::GetDeviceProperties(id);
@@ -3211,6 +3229,13 @@
           [](BuildStrategy &self, bool fix_op_run_order) {
             self.fix_op_run_order_ = fix_op_run_order;
           })
+      .def_property("allow_cuda_graph_capture",
+                    [](const BuildStrategy &self) {
+                      return self.allow_cuda_graph_capture_;
+                    },
+                    [](BuildStrategy &self, bool allow_cuda_graph_capture) {
+                      self.allow_cuda_graph_capture_ = allow_cuda_graph_capture;
+                    })
       .def("_copy",
            [](const BuildStrategy &self) {
              auto new_bs = self;
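The new Tensor._copy_from binding copies another tensor, optionally only its first batch_size rows, into an existing tensor; this is what allows persistable input buffers to be refreshed between graph replays without feeding. A sketch against the binding above (private API; the tensors and shapes are illustrative):

    import numpy as np
    import paddle
    from paddle.fluid import core

    place = paddle.CUDAPlace(0)
    src = core.LoDTensor()
    src.set(np.random.rand(8, 4).astype('float32'), place)
    dst = core.LoDTensor()
    dst.set(np.zeros([8, 4], dtype='float32'), place)

    dst._copy_from(src, place)                # batch_size = -1: copy everything
    dst._copy_from(src, place, batch_size=2)  # copy src[0:2]; dst takes the sliced shape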
diff --git a/python/paddle/fluid/executor.py b/python/paddle/fluid/executor.py
index 4c7537d8d5..8c118f31cb 100644
--- a/python/paddle/fluid/executor.py
+++ b/python/paddle/fluid/executor.py
@@ -1044,9 +1044,15 @@ class Executor(object):
             lr_value = lr_sheduler()
             lr_var = program._program.global_block().vars[lr_sheduler._var_name]
             lr_tensor = _as_lodtensor(lr_value, core.CPUPlace(), lr_var.dtype)
-            exe.feed_and_split_tensor_into_local_scopes({
-                lr_sheduler._var_name: lr_tensor
-            })
+            if core.is_cuda_graph_capturing():
+                warnings.warn(
+                    "Caution!!! When capturing CUDA Graph, the learning rate scheduler would not "
+                    "take any effect! Please set the learning rate manually before each batch!"
+                )
+            else:
+                exe.feed_and_split_tensor_into_local_scopes({
+                    lr_sheduler._var_name: lr_tensor
+                })
 
         fetch_var_names = list(map(_to_name_str, fetch_list))
         tensors = exe.run(fetch_var_names, return_merged)._move_to_list()
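Since feeding is disallowed while a graph is capturing or replaying, the scheduler value has to be written into the learning-rate tensor by hand before each replay. The new static-graph test does exactly this; distilled below (lr, lr_t, place, and cuda_graph come from the surrounding test code):

    import numpy as np

    # The LR variable is persistable, so its tensor can be overwritten
    # in place right before replaying the captured batch.
    lr_t.set(np.array([lr()], dtype='float32'), place)
    cuda_graph.replay()
    lr.step()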
diff --git a/python/paddle/fluid/tests/unittests/test_cuda_graph.py b/python/paddle/fluid/tests/unittests/test_cuda_graph.py
index 272d68e17f..7d13174735 100644
--- a/python/paddle/fluid/tests/unittests/test_cuda_graph.py
+++ b/python/paddle/fluid/tests/unittests/test_cuda_graph.py
@@ -17,18 +17,105 @@
 import paddle.fluid as fluid
 from paddle.device.cuda.graphs import CUDAGraph
 import unittest
 import numpy as np
+from paddle.fluid.dygraph.base import switch_to_static_graph
+from simple_nets import simple_fc_net_with_inputs
 
 
 class TestCUDAGraph(unittest.TestCase):
     def setUp(self):
-        fluid.set_flags({'FLAGS_allocator_strategy': 'auto_growth'})
+        if paddle.is_compiled_with_cuda() and not paddle.is_compiled_with_rocm(
+        ):
+            fluid.set_flags({
+                'FLAGS_allocator_strategy': 'auto_growth',
+                'FLAGS_sync_nccl_allreduce': False,
+                'FLAGS_cudnn_deterministic': True
+            })
 
     def random_tensor(self, shape):
         return paddle.to_tensor(
             np.random.randint(
                 low=0, high=10, size=shape).astype("float32"))
 
-    def test_cuda_graph(self):
+    @switch_to_static_graph
+    def test_cuda_graph_static_graph(self):
+        if not paddle.is_compiled_with_cuda() or paddle.is_compiled_with_rocm():
+            return
+
+        seed = 100
+        loss_cuda_graph = self.cuda_graph_static_graph_main(
+            seed, use_cuda_graph=True)
+        loss_no_cuda_graph = self.cuda_graph_static_graph_main(
+            seed, use_cuda_graph=False)
+        self.assertEqual(loss_cuda_graph, loss_no_cuda_graph)
+
+    def cuda_graph_static_graph_main(self, seed, use_cuda_graph):
+        batch_size = 1
+        class_num = 10
+        image_shape = [batch_size, 784]
+        label_shape = [batch_size, 1]
+
+        paddle.seed(seed)
+        np.random.seed(seed)
+        startup = paddle.static.Program()
+        main = paddle.static.Program()
+        with paddle.static.program_guard(main, startup):
+            image = paddle.static.data(
+                name="image", shape=image_shape, dtype='float32')
+            label = paddle.static.data(
+                name="label", shape=label_shape, dtype='int64')
+            image.persistable = True
+            label.persistable = True
+            loss = simple_fc_net_with_inputs(image, label, class_num)
+            loss.persistable = True
+            lr = paddle.optimizer.lr.PiecewiseDecay(
+                boundaries=[2, 3, 4], values=[0.01, 0.02, 0.03, 0.04])
+            optimizer = paddle.optimizer.SGD(learning_rate=lr)
+            optimizer.minimize(loss)
+        place = paddle.CUDAPlace(0)
+        exe = paddle.static.Executor(place)
+        scope = paddle.static.Scope()
+        with paddle.static.scope_guard(scope):
+            exe.run(startup)
+            build_strategy = paddle.static.BuildStrategy()
+            build_strategy.allow_cuda_graph_capture = True
+            build_strategy.fix_op_run_order = True
+            build_strategy.fuse_all_optimizer_ops = True
+            compiled_program = paddle.static.CompiledProgram(
+                main).with_data_parallel(
+                    loss_name=loss.name,
+                    build_strategy=build_strategy,
+                    places=place)
+            image_t = scope.var(image.name).get_tensor()
+            label_t = scope.var(label.name).get_tensor()
+            loss_t = scope.var(loss.name).get_tensor()
+            lr_var = main.global_block().var(lr._var_name)
+            self.assertTrue(lr_var.persistable)
+            lr_t = scope.var(lr_var.name).get_tensor()
+            cuda_graph = None
+            for batch_id in range(20):
+                image_t.set(
+                    np.random.rand(*image_shape).astype('float32'), place)
+                label_t.set(np.random.randint(
+                    low=0, high=class_num, size=label_shape, dtype='int64'),
+                            place)
+
+                if batch_id == 1 and use_cuda_graph:
+                    cuda_graph = CUDAGraph(place, mode="global")
+                    cuda_graph.capture_begin()
+                    exe.run(compiled_program)
+                    cuda_graph.capture_end()
+
+                if cuda_graph:
+                    lr_t.set(np.array([lr()], dtype='float32'), place)
+                    cuda_graph.replay()
+                else:
+                    exe.run(compiled_program)
+                lr.step()
+            if cuda_graph:
+                cuda_graph.reset()
+        return np.array(loss_t)
+
+    def test_cuda_graph_dynamic_graph(self):
         if not paddle.is_compiled_with_cuda() or paddle.is_compiled_with_rocm():
             return
--
GitLab