diff --git a/paddle/fluid/framework/details/fast_threaded_ssa_graph_executor.cc b/paddle/fluid/framework/details/fast_threaded_ssa_graph_executor.cc
index ce471d55b24a18a7f291db0dc6d93026940096a2..8b5c3c179878090e450d7aa7eeecc6b67b1b3c72 100644
--- a/paddle/fluid/framework/details/fast_threaded_ssa_graph_executor.cc
+++ b/paddle/fluid/framework/details/fast_threaded_ssa_graph_executor.cc
@@ -46,6 +46,12 @@ FastThreadedSSAGraphExecutor::FastThreadedSSAGraphExecutor(
     VLOG(10) << "Change thread number to 1 because the toposort order is unique";
     strategy_.num_threads_ = 1;
+    traced_ops_.clear();
+    for (auto *op_node : TopologySortOperations(*graph_)) {
+      if (op_node->IsWrappedBy<OpHandleBase>()) {
+        traced_ops_.emplace_back(&(op_node->Wrapper<OpHandleBase>()));
+      }
+    }
   }
   pool_.reset(new ::ThreadPool(strategy.num_threads_));
   for (auto &op : ir::FilterByNodeWrapper<OpHandleBase>(*graph_)) {
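Context for the hunk above: with `fix_op_run_order_` the executor drops to a single thread and records the ops in the unique topological order, so every run launches kernels in exactly the same sequence — the precondition for capturing and replaying a CUDA Graph. A minimal sketch of that idea, with stand-in types (the real dispatch goes through `OpHandleBase` and the executor's own machinery):

```cpp
#include <functional>
#include <vector>

// Hypothetical stand-in for OpHandleBase dispatch; the point is only that a
// fixed sequence yields an identical kernel launch order on every call.
using Op = std::function<void()>;

void RunTracedOrder(const std::vector<Op> &traced_ops) {
  for (const auto &op : traced_ops) {
    op();  // same order every run => replay-safe launch order
  }
}
```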
diff --git a/paddle/fluid/framework/executor_cache.cc b/paddle/fluid/framework/executor_cache.cc
index 0ab4bd5a12b0687baf92e2d1544dfb36ff6a3b70..50a41cb5611e10be4d7a445cbff6ca051d895913 100644
--- a/paddle/fluid/framework/executor_cache.cc
+++ b/paddle/fluid/framework/executor_cache.cc
@@ -137,6 +137,31 @@ ExecutorInfoCache &ExecutorInfoCache::Instance() {
   return g_exe_cache_info_map;
 }
 
+static PEAndGraphPair CreateExecutorInfo(
+    const ProgramDesc &program_desc, const platform::Place &place,
+    int64_t start_op_index, int64_t end_op_index, framework::Scope *scope,
+    const details::BuildStrategy &build_strategy) {
+  auto execution_strategy = details::GetExecutionStrategy(place);
+  auto graph = std::make_shared<framework::ir::Graph>(
+      program_desc, start_op_index, end_op_index);
+  auto parallel_executor = std::make_shared<framework::ParallelExecutor>(
+      place, scope, execution_strategy, build_strategy, graph.get());
+  parallel_executor->PrepareVariables(scope);
+  return std::make_pair(parallel_executor, graph);
+}
+
+PEAndGraphPair CreateFixOrderExecutorInfo(const ProgramDesc &program_desc,
+                                          const platform::Place &place,
+                                          int64_t start_op_index,
+                                          int64_t end_op_index,
+                                          framework::Scope *scope) {
+  details::BuildStrategy build_strategy;
+  build_strategy.fix_op_run_order_ = true;
+  auto pe_and_graph = CreateExecutorInfo(program_desc, place, start_op_index,
+                                         end_op_index, scope, build_strategy);
+  return pe_and_graph;
+}
+
 CacheInfo GetExecutorInfoFromCache(const ProgramDesc &program_desc,
                                    const platform::Place &place,
                                    int64_t start_op_index, int64_t end_op_index,
@@ -153,21 +178,17 @@ CacheInfo GetExecutorInfoFromCache(const ProgramDesc &program_desc,
   }
 
     VLOG(1) << "create exe_info for " << program_id << " is_grad: " << is_grad;
-    auto execution_strategy = details::GetExecutionStrategy(place);
     auto &build_strategy = cached_exe_info.GetBuildStrategy(program_id);
 
     // 2. Construct Graph and ParallelExecutor.
-    auto graph = std::make_shared<framework::ir::Graph>(
-        program_desc, start_op_index, end_op_index);
-    auto parallel_executor = std::make_shared<framework::ParallelExecutor>(
-        place, scope, execution_strategy, build_strategy, graph.get());
-    parallel_executor->PrepareVariables(scope);
+    auto pe_and_graph = CreateExecutorInfo(program_desc, place, start_op_index,
+                                           end_op_index, scope, build_strategy);
 
     // 3. Insert value into cached map.
     auto &cached_value = cached_exe_info.GetMutable(program_id, is_grad);
-    cached_value.executor_ = parallel_executor;
-    cached_value.graph_ = std::move(graph);
-    return std::make_pair(parallel_executor, /*is_new_created=*/true);
+    cached_value.executor_ = pe_and_graph.first;
+    cached_value.graph_ = pe_and_graph.second;
+    return std::make_pair(pe_and_graph.first, /*is_new_created=*/true);
   } else {
     VLOG(1) << "get exe_info from cache by: " << program_id
             << " is_grad: " << is_grad;
diff --git a/paddle/fluid/framework/executor_cache.h b/paddle/fluid/framework/executor_cache.h
index 8207b56fc04f16fa13746a011fa602a65f0b52a0..25c0bfab90c4af49e4765249d977beb50840ffd6 100644
--- a/paddle/fluid/framework/executor_cache.h
+++ b/paddle/fluid/framework/executor_cache.h
@@ -127,11 +127,20 @@ class ExecutorInfoCache {
 
 using CacheInfo =
     std::pair<std::shared_ptr<ParallelExecutor>, bool /*is_new_created*/>;
 
+using PEAndGraphPair =
+    std::pair<std::shared_ptr<ParallelExecutor>, std::shared_ptr<ir::Graph>>;
+
 CacheInfo GetExecutorInfoFromCache(const ProgramDesc& program_desc,
                                    const platform::Place& place,
                                    int64_t start_op_index,
                                    int64_t end_op_index, bool is_grad,
                                    int64_t program_id, framework::Scope* scope);
 
+PEAndGraphPair CreateFixOrderExecutorInfo(const ProgramDesc& program_desc,
+                                          const platform::Place& place,
+                                          int64_t start_op_index,
+                                          int64_t end_op_index,
+                                          framework::Scope* scope);
+
 }  // namespace framework
 }  // namespace paddle
diff --git a/paddle/fluid/framework/var_type_traits.cc b/paddle/fluid/framework/var_type_traits.cc
index 401ccb03d78d6f14c21c803ec8209ef36ec83e54..ec664b4513f2cd41816985d289954779792e7ecf 100644
--- a/paddle/fluid/framework/var_type_traits.cc
+++ b/paddle/fluid/framework/var_type_traits.cc
@@ -41,6 +41,8 @@
 #include "paddle/fluid/platform/device/xpu/bkcl_helper.h"
 #endif
 
+#include "paddle/fluid/operators/cuda_graph_with_in_out.h"
+
 namespace paddle {
 namespace framework {
diff --git a/paddle/fluid/framework/var_type_traits.h b/paddle/fluid/framework/var_type_traits.h
index 9fe67e1dcdff31c23f8febca53324dd01736dc70..463331494d908af90fc1ba2d08f0b945a4ffe891 100644
--- a/paddle/fluid/framework/var_type_traits.h
+++ b/paddle/fluid/framework/var_type_traits.h
@@ -87,6 +87,8 @@ namespace operators {
 class CudnnRNNCache;
 
+class CUDAGraphWithInOuts;
+
 namespace reader {
 class LoDTensorBlockingQueueHolder;
 class OrderedMultiDeviceLoDTensorBlockingQueueHolder;
@@ -189,7 +191,8 @@ using VarTypeRegistry = detail::VarTypeRegistryImpl<
 #if defined(PADDLE_WITH_CNCL)
     cnclCliqueId,
 #endif
-    int, float, Vocab>;
+    std::vector<std::unique_ptr<operators::CUDAGraphWithInOuts>>, int, float,
+    Vocab>;
 
 template <typename T>
 struct VarTypeTrait {
   static_assert(VarTypeRegistry::IsRegistered<T>(), "Must be registered type");
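A note on `PEAndGraphPair` introduced above: `ParallelExecutor` is constructed from a raw `Graph*`, so whoever owns the executor must also keep the `shared_ptr<Graph>` alive. Bundling both in one pair ties the two lifetimes together. A minimal sketch of the shape, with stand-in types:

```cpp
#include <memory>
#include <utility>

struct Graph {};    // stand-in for framework::ir::Graph
struct Executor {   // stand-in for framework::ParallelExecutor
  explicit Executor(Graph *g) : g_(g) {}
  Graph *g_;        // non-owning: the shared_ptr below must outlive it
};

using PEAndGraphPair =
    std::pair<std::shared_ptr<Executor>, std::shared_ptr<Graph>>;

PEAndGraphPair Make() {
  auto graph = std::make_shared<Graph>();
  auto exec = std::make_shared<Executor>(graph.get());
  return {exec, graph};  // returning both keeps the raw pointer valid
}
```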
diff --git a/paddle/fluid/memory/allocation/allocator_facade.cc b/paddle/fluid/memory/allocation/allocator_facade.cc
index 46e1a500e48705b6cceeab566622d232cda445f3..7cd5fffea2ad6da4ec2ddd49e8750081ef88040e 100644
--- a/paddle/fluid/memory/allocation/allocator_facade.cc
+++ b/paddle/fluid/memory/allocation/allocator_facade.cc
@@ -123,6 +123,8 @@ class CUDAGraphAllocator
       : underlying_allocator_(allocator) {}
 
  public:
+  ~CUDAGraphAllocator() { VLOG(10) << "CUDAGraphAllocator destructed"; }
+
   static std::shared_ptr<Allocator> Create(
       const std::shared_ptr<Allocator>& allocator) {
     return std::shared_ptr<Allocator>(new CUDAGraphAllocator(allocator));
   }
@@ -973,7 +975,7 @@ AllocatorFacade& AllocatorFacade::Instance() {
 AllocatorFacadePrivate* AllocatorFacade::GetPrivate() const {
 #ifdef PADDLE_WITH_CUDA
   if (UNLIKELY(IsCUDAGraphCapturing())) {
-    auto id = platform::CUDAGraph::CapturingID();
+    auto id = platform::CUDAGraph::CapturingPoolID();
     auto iter = cuda_graph_map_.find(id);
     PADDLE_ENFORCE_NE(
         iter, cuda_graph_map_.end(),
@@ -1116,7 +1118,7 @@ void AllocatorFacade::SetDefaultStream(const platform::CUDAPlace& place,
 }
 
 #ifdef PADDLE_WITH_CUDA
-void AllocatorFacade::PrepareMemoryPoolForCUDAGraph(CUDAGraphID id) {
+void AllocatorFacade::PrepareMemoryPoolForCUDAGraph(int64_t id) {
   PADDLE_ENFORCE_EQ(GetAllocatorStrategy(), AllocatorStrategy::kAutoGrowth,
                     platform::errors::InvalidArgument(
                         "CUDA Graph is only supported when the "
@@ -1124,23 +1126,32 @@ void AllocatorFacade::PrepareMemoryPoolForCUDAGraph(CUDAGraphID id) {
                         "FLAGS_allocator_strategy=\"%s\"",
                         FLAGS_allocator_strategy));
   auto& allocator = cuda_graph_map_[id];
-  PADDLE_ENFORCE_EQ(
-      allocator.get(), nullptr,
-      platform::errors::InvalidArgument(
-          "The memory pool of the CUDA Graph with ID %d have been prepared.",
-          id));
-  allocator.reset(new AllocatorFacadePrivate(/*allow_free_idle_chunk=*/false));
-
-  VLOG(10) << "Prepare memory pool for CUDA Graph with ID " << id;
+  auto& ref_cnt = cuda_graph_ref_cnt_[id];
+  if (allocator.get() == nullptr) {
+    allocator.reset(
+        new AllocatorFacadePrivate(/*allow_free_idle_chunk=*/false));
+    VLOG(10) << "Create memory pool for CUDA Graph with memory ID " << id;
+  } else {
+    VLOG(10) << "Use created memory pool for CUDA Graph with memory ID " << id;
+  }
+  ++ref_cnt;
 }
 
-void AllocatorFacade::RemoveMemoryPoolOfCUDAGraph(CUDAGraphID id) {
-  auto iter = cuda_graph_map_.find(id);
-  PADDLE_ENFORCE_NE(iter, cuda_graph_map_.end(),
+void AllocatorFacade::RemoveMemoryPoolOfCUDAGraph(int64_t id) {
+  auto ref_cnt_iter = cuda_graph_ref_cnt_.find(id);
+  PADDLE_ENFORCE_NE(ref_cnt_iter, cuda_graph_ref_cnt_.end(),
                     platform::errors::InvalidArgument(
-                        "Cannot find CUDA Graph with ID = %d", id));
-  cuda_graph_map_.erase(iter);
-  VLOG(10) << "Remove memory pool of CUDA Graph with ID " << id;
+                        "Cannot find CUDA Graph with memory ID = %d", id));
+  auto& ref_cnt = ref_cnt_iter->second;
+  --ref_cnt;
+  if (ref_cnt == 0) {
+    cuda_graph_map_.erase(id);
+    cuda_graph_ref_cnt_.erase(ref_cnt_iter);
+    VLOG(10) << "Remove memory pool of CUDA Graph with memory ID " << id;
+  } else {
+    VLOG(10) << "Decrease memory pool ID " << id << " reference count to be "
+             << ref_cnt;
+  }
 }
 #endif
 #endif
diff --git a/paddle/fluid/memory/allocation/allocator_facade.h b/paddle/fluid/memory/allocation/allocator_facade.h
index 1dea50edccf2eb7b8efb196ba102a04874546683..94b07e3e6c1efc591b9e75475786fbb375dc771e 100644
--- a/paddle/fluid/memory/allocation/allocator_facade.h
+++ b/paddle/fluid/memory/allocation/allocator_facade.h
@@ -89,8 +89,8 @@ class AllocatorFacade {
 #endif
 
 #ifdef PADDLE_WITH_CUDA
-  void PrepareMemoryPoolForCUDAGraph(CUDAGraphID id);
-  void RemoveMemoryPoolOfCUDAGraph(CUDAGraphID id);
+  void PrepareMemoryPoolForCUDAGraph(int64_t id);
+  void RemoveMemoryPoolOfCUDAGraph(int64_t id);
 #endif
 
   // TODO(yy): Allocate a Copy-On-Write allocation?
@@ -98,8 +98,9 @@ class AllocatorFacade {
   AllocatorFacade();
   AllocatorFacadePrivate* m_;
 #ifdef PADDLE_WITH_CUDA
-  std::unordered_map<CUDAGraphID, std::unique_ptr<AllocatorFacadePrivate>>
+  std::unordered_map<int64_t, std::unique_ptr<AllocatorFacadePrivate>>
       cuda_graph_map_;
+  std::unordered_map<int64_t, int64_t> cuda_graph_ref_cnt_;
 #endif
 };
diff --git a/paddle/fluid/operators/CMakeLists.txt b/paddle/fluid/operators/CMakeLists.txt
index 3112d0d8205a86326b7ecd9b86ffac486291f6b3..b2fd59b47454e1a0020bc3f69bc9bec13a4e21e4 100644
--- a/paddle/fluid/operators/CMakeLists.txt
+++ b/paddle/fluid/operators/CMakeLists.txt
@@ -107,6 +107,7 @@ register_operators(EXCLUDES py_layer_op py_func_op warpctc_op dgc_op load_combin
     recurrent_op save_combine_op sparse_attention_op sync_batch_norm_op spectral_op ${OP_MKL_DEPS} DEPS ${OP_HEADER_DEPS})
 
 op_library(run_program_op SRCS run_program_op.cc run_program_op.cu.cc DEPS executor_cache ${OP_HEADER_DEPS})
+target_link_libraries(run_program_op cuda_graph_with_memory_pool)
 op_library(quantize_linear_op DEPS cast_kernel)
 op_library(save_combine_op DEPS string_array)
 op_library(load_combine_op DEPS string_array)
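The change above replaces the one-graph-one-pool invariant with reference counting: several CUDA Graphs may share one memory pool id, and the pool is destroyed only when the last graph using it is reset. A minimal self-contained sketch of the scheme (stand-in `Pool` for `AllocatorFacadePrivate`):

```cpp
#include <cstdint>
#include <memory>
#include <unordered_map>

struct Pool {};  // stand-in for AllocatorFacadePrivate

std::unordered_map<int64_t, std::unique_ptr<Pool>> pools;
std::unordered_map<int64_t, int64_t> ref_cnts;

void Prepare(int64_t id) {
  auto &pool = pools[id];
  if (pool == nullptr) pool.reset(new Pool());  // first user creates the pool
  ++ref_cnts[id];
}

void Remove(int64_t id) {
  if (--ref_cnts[id] == 0) {  // last user frees the pool
    pools.erase(id);
    ref_cnts.erase(id);
  }
}
```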
diff --git a/paddle/fluid/operators/cuda_graph_with_in_out.h b/paddle/fluid/operators/cuda_graph_with_in_out.h
new file mode 100644
index 0000000000000000000000000000000000000000..e7a943aee4d364c51cbdf0ff83d32935731365fc
--- /dev/null
+++ b/paddle/fluid/operators/cuda_graph_with_in_out.h
@@ -0,0 +1,156 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include "paddle/fluid/framework/tensor.h"
+#ifdef PADDLE_WITH_CUDA
+#include "paddle/fluid/platform/cuda_graph_with_memory_pool.h"
+#endif
+
+namespace paddle {
+namespace operators {
+
+#ifdef PADDLE_WITH_CUDA
+class CUDAGraphWithInOuts {
+ public:
+  template <typename Callable>
+  CUDAGraphWithInOuts(Callable &&callable, platform::CUDAPlace place,
+                      const std::vector<const framework::Tensor *> &in_ptrs,
+                      cudaStreamCaptureMode mode, int64_t pool_id) {
+    in_indices_.resize(in_ptrs.size());
+    ins_.reserve(in_ptrs.size());
+    int64_t valid_in_idx = 0;
+    for (size_t i = 0; i < in_ptrs.size(); ++i) {
+      if (in_ptrs[i] == nullptr) {
+        in_indices_[i] = -1;
+      } else {
+        in_indices_[i] = (valid_in_idx++);
+        ins_.push_back(*in_ptrs[i]);
+      }
+    }
+
+    platform::BeginCUDAGraphCapture(place, mode, pool_id);
+    auto out_ptrs = callable(in_ptrs);
+    graph_ = platform::EndCUDAGraphCapture();
+    graph_->Replay();
+
+    out_indices_.resize(out_ptrs.size());
+    outs_.reserve(out_ptrs.size());
+    int64_t valid_out_idx = 0;
+    for (size_t i = 0; i < out_ptrs.size(); ++i) {
+      if (out_ptrs[i] == nullptr) {
+        out_indices_[i] = -1;
+      } else {
+        out_indices_[i] = (valid_out_idx++);
+        outs_.push_back(*out_ptrs[i]);
+      }
+    }
+  }
+
+  void Run(const std::vector<const framework::Tensor *> &ins) {
+    PADDLE_ENFORCE_EQ(
+        ins.size(), in_indices_.size(),
+        phi::errors::InvalidArgument("The input number does not match."));
+    for (size_t i = 0; i < in_indices_.size(); ++i) {
+      if (in_indices_[i] >= 0) {
+        auto *dst = &ins_[in_indices_[i]];
+        framework::TensorCopy(*ins[i], dst->place(), dst);
+      }
+    }
+    graph_->Replay();
+  }
+
+  std::vector<framework::Tensor *> GetOutputs() {
+    std::vector<framework::Tensor *> outs(out_indices_.size());
+    for (size_t i = 0; i < out_indices_.size(); ++i) {
+      if (out_indices_[i] >= 0) {
+        outs[i] = &outs_[out_indices_[i]];
+      }
+    }
+    return outs;
+  }
+
+  int64_t PoolID() const { return graph_->PoolID(); }
+
+ private:
+  std::unique_ptr<platform::CUDAGraph> graph_;
+  std::vector<framework::Tensor> ins_;
+  std::vector<framework::Tensor> outs_;
+  std::vector<int64_t> in_indices_;
+  std::vector<int64_t> out_indices_;
+};
+
+template <typename Callable>
+static std::unique_ptr<CUDAGraphWithInOuts> CaptureCUDAGraph(
+    Callable &&callable, const framework::ExecutionContext &ctx,
+    const std::vector<std::string> &input_names,
+    const std::vector<std::string> &output_names, cudaStreamCaptureMode mode,
+    int64_t pool_id) {
+  std::vector<const framework::Tensor *> inputs;
+  for (const auto &name : input_names) {
+    auto input_tensors = ctx.MultiInput<framework::Tensor>(name);
+    inputs.insert(inputs.end(), input_tensors.begin(), input_tensors.end());
+  }
+
+  auto func = [&](const std::vector<const framework::Tensor *> &inputs) {
+    callable(ctx);
+    std::vector<framework::Tensor *> outputs;
+    for (const auto &name : output_names) {
+      auto output_tensors = ctx.MultiOutput<framework::Tensor>(name);
+      outputs.insert(outputs.end(), output_tensors.begin(),
+                     output_tensors.end());
+    }
+    return outputs;
+  };
+
+  return std::make_unique<CUDAGraphWithInOuts>(func, ctx.GetPlace(), inputs,
+                                               mode, pool_id);
+}
+
+static void ExecuteCUDAGraph(const framework::ExecutionContext &ctx,
+                             const std::vector<std::string> &input_names,
+                             const std::vector<std::string> &output_names,
+                             CUDAGraphWithInOuts *graph) {
+  std::vector<const framework::Tensor *> inputs;
+  for (const auto &name : input_names) {
+    auto input_tensors = ctx.MultiInput<framework::Tensor>(name);
+    inputs.insert(inputs.end(), input_tensors.begin(), input_tensors.end());
+  }
+
+  graph->Run(inputs);
+  auto outputs = graph->GetOutputs();
+
+  size_t idx = 0;
+  for (const auto &name : output_names) {
+    auto output_tensors = ctx.MultiOutput<framework::Tensor>(name);
+    for (auto *out_t : output_tensors) {
+      if (outputs[idx] != nullptr) {
+        *out_t = *outputs[idx];
+      } else {
+        PADDLE_ENFORCE_EQ(
+            out_t, nullptr,
+            phi::errors::InvalidArgument(
+                "The %d-th output variable should be nullptr.", idx));
+      }
+      ++idx;
+    }
+  }
+}
+#else
+class CUDAGraphWithInOuts {};
+#endif
+
+}  // namespace operators
+}  // namespace paddle
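The constructor above compacts its input/output pointer lists: `nullptr` slots (e.g. dispensable tensors) map to index `-1`, while valid slots map to their position in a packed copy. A self-contained sketch of that index-compaction step on plain ints:

```cpp
#include <cstdint>
#include <vector>

// nullptr slots -> -1; valid slots -> position inside the packed vector.
std::vector<int64_t> BuildIndices(const std::vector<const int *> &ptrs,
                                  std::vector<int> *packed) {
  std::vector<int64_t> indices(ptrs.size());
  int64_t valid = 0;
  for (size_t i = 0; i < ptrs.size(); ++i) {
    if (ptrs[i] == nullptr) {
      indices[i] = -1;       // missing tensor, nothing to copy on Run()
    } else {
      indices[i] = valid++;  // where the packed copy lives
      packed->push_back(*ptrs[i]);
    }
  }
  return indices;
}
```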
diff --git a/paddle/fluid/operators/dropout_impl.cu.h b/paddle/fluid/operators/dropout_impl.cu.h
index 6af8c925ff580292438159e52eff884d7ac10232..482f88b73e616cd96b61beeb985ceb1785507b93 100644
--- a/paddle/fluid/operators/dropout_impl.cu.h
+++ b/paddle/fluid/operators/dropout_impl.cu.h
@@ -34,6 +34,7 @@ limitations under the License. */
 #include "paddle/fluid/operators/dropout_impl_util.h"
 #include "paddle/fluid/operators/elementwise/elementwise_op_impl.cu.h"
 #include "paddle/fluid/platform/aligned_vector.h"
+#include "paddle/fluid/platform/cuda_graph_with_memory_pool.h"
 #include "paddle/phi/backends/gpu/gpu_launch_config.h"
 #include "paddle/phi/kernels/funcs/distribution_helper.h"
 #include "paddle/phi/kernels/funcs/functors.h"
@@ -195,9 +196,11 @@ void DropoutFwGPUKernelDriver(const phi::GPUContext& dev_ctx, bool is_test,
     size_t main_offset =
         size / (block_size * kVecSize) * (block_size * kVecSize);
 
-    VectorizedRandomGenerator<T, uint8_t><<<grid_size, block_size, 0, stream>>>(
-        size, seed_data, dropout_prob, x_data, mask_data, y_data,
-        upscale_in_train, increment, main_offset);
+    PD_RECORD_CUDA_GRAPH_RANDOM_KERNEL(
+        !is_fix_seed, (VectorizedRandomGenerator<T, uint8_t>), grid_size,
+        block_size, 0, stream, offset, KERNEL_PARAMS.As<uint64_t>(1),
+        KERNEL_PARAMS.As<uint64_t>(7), size, seed_data, dropout_prob, x_data,
+        mask_data, y_data, upscale_in_train, increment, main_offset);
   } else {
     if (upscale_in_train) {
       // todo: can y share with data with x directly?
diff --git a/paddle/fluid/operators/run_program_op.cc b/paddle/fluid/operators/run_program_op.cc
index ec62feb07bc80711665fa8179a1a11cb040fa130..38c92de4523d5d7a481d8cb0de00506272f19243 100644
--- a/paddle/fluid/operators/run_program_op.cc
+++ b/paddle/fluid/operators/run_program_op.cc
@@ -90,6 +90,8 @@ class RunProgramOpMaker : public framework::OpProtoAndCheckerMaker {
               "computes double grad.")
         .AsDuplicable()
         .AsDispensable();
+    AddOutput("CUDAGraph", "The output CUDA Graph when use_cuda_graph=True.")
+        .AsDispensable();
     AddAttr<BlockDesc*>("global_block",
                         "(BlockDesc *)"
                         "The global block of executed program desc.");
@@ -107,6 +109,13 @@ class RunProgramOpMaker : public framework::OpProtoAndCheckerMaker {
         "program_id",
         "(int64_t)"
         "The unique hash id used as cache key for ExecutorInfoCache.");
+    AddAttr<std::string>("cuda_graph_capture_mode",
+                         "(str, default '') The CUDA Graph capture mode. "
+                         "Default '' means no CUDA Graph capturing.")
+        .SetDefault("");
+    AddAttr<int64_t>("cuda_graph_pool_id",
+                     "(int64_t, default 0) The CUDA Graph memory pool ID.")
+        .SetDefault(0);
     AddComment(R"DOC(
 RunProgram operator.
@@ -191,6 +200,9 @@ class RunProgramGradOpMaker : public framework::SingleGradOpMaker<T> {
     grad_op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
     grad_op->SetInput("OutScope", this->Output("OutScope"));
     grad_op->SetInput("DOut", this->Output("DOut"));
+    if (this->HasOutput("CUDAGraph")) {
+      grad_op->SetInput("CUDAGraph", this->Output("CUDAGraph"));
+    }
     grad_op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
 
     auto block_desc =
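The two `KERNEL_PARAMS.As<uint64_t>(...)` expressions in the dropout change pick out the random-state kernel arguments by position. The layout below is inferred from the argument list at the call site, not restated from the kernel's declaration, so treat it as an assumption:

```cpp
// Inferred positions of VectorizedRandomGenerator's launch arguments at this
// call site; indices 1 and 7 are the ones the macro rewrites before replay.
enum VectorizedRandomGeneratorArgIndex : int {
  kSize = 0,
  kSeed = 1,       // KERNEL_PARAMS.As<uint64_t>(1)
  kDropoutProb = 2,
  kXData = 3,
  kMaskData = 4,
  kYData = 5,
  kUpscaleInTrain = 6,
  kIncrement = 7,  // KERNEL_PARAMS.As<uint64_t>(7)
  kMainOffset = 8,
};
```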
diff --git a/paddle/fluid/operators/run_program_op.h b/paddle/fluid/operators/run_program_op.h
index fbc52480c826677d984b868ffc98866b13536b6a..8007f0bc37b1ff9a022ccaa130e83b236b97d88c 100644
--- a/paddle/fluid/operators/run_program_op.h
+++ b/paddle/fluid/operators/run_program_op.h
@@ -34,6 +34,9 @@ limitations under the License. */
 #ifdef PADDLE_WITH_MKLDNN
 #include "paddle/fluid/platform/mkldnn_helper.h"
 #endif
+#ifdef PADDLE_WITH_CUDA
+#include "paddle/fluid/operators/cuda_graph_with_in_out.h"
+#endif
 
 DECLARE_bool(use_mkldnn);
 
@@ -167,13 +170,84 @@ static void ShareVarsFromScope(const std::vector<Variable *> &vars,
   }
 }
 
+#ifdef PADDLE_WITH_CUDA
+static cudaStreamCaptureMode StringToCUDAGraphCaptureMode(
+    const std::string &mode) {
+  if (mode == "global") {
+    return cudaStreamCaptureModeGlobal;
+  } else if (mode == "thread_local") {
+    return cudaStreamCaptureModeThreadLocal;
+  } else if (mode == "relaxed") {
+    return cudaStreamCaptureModeRelaxed;
+  } else {
+    PADDLE_THROW(phi::errors::InvalidArgument(
+        "Unsupported CUDA Graph capture mode %s", mode));
+  }
+}
+#endif
+
 }  // namespace details
 
 template <typename DeviceContext, typename T>
 class RunProgramOpKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext &ctx) const override {
+    const auto &capture_mode =
+        ctx.Attr<std::string>("cuda_graph_capture_mode");
+    auto is_test = ctx.Attr<bool>("is_test");
+    if (capture_mode.empty()) {
+      ComputeImpl(ctx, is_test, false);
+      return;
+    }
+
+#ifdef PADDLE_WITH_CUDA
+    auto mode = details::StringToCUDAGraphCaptureMode(capture_mode);
+    PADDLE_ENFORCE_EQ(
+        platform::is_gpu_place(ctx.GetPlace()), true,
+        phi::errors::InvalidArgument("The cuda_graph_capture_mode is only "
+                                     "valid when using NVIDIA GPU."));
+    auto *graph_var = ctx.OutputVar("CUDAGraph");
+    PADDLE_ENFORCE_NOT_NULL(
+        graph_var,
+        phi::errors::InvalidArgument("Output(CUDAGraph) must exist when "
+                                     "cuda_graph_capture_mode is valid."));
+    using GraphVecType = std::vector<std::unique_ptr<CUDAGraphWithInOuts>>;
+    auto &inner_graphs = *(graph_var->GetMutable<GraphVecType>());
+    inner_graphs.resize(std::max<size_t>(3, inner_graphs.size()));
+    size_t graph_idx = is_test ? 0 : 1;
+    if (inner_graphs[graph_idx].get() == nullptr) {
+      int64_t pool_id;
+      if (inner_graphs[1 - graph_idx].get() != nullptr) {
+        pool_id = inner_graphs[1 - graph_idx]->PoolID();
+      } else {
+        pool_id = ctx.Attr<int64_t>("cuda_graph_pool_id");
+      }
+
+      framework::PEAndGraphPair pe_and_graph;
+      auto callable = [this, is_test, &pe_and_graph](
+                          const framework::ExecutionContext &exe_ctx) {
+        pe_and_graph = ComputeImpl(exe_ctx, is_test, true);
+      };
+      inner_graphs[graph_idx] = CaptureCUDAGraph(
+          callable, ctx, {"X"}, {"Out", "DOut"}, mode, pool_id);
+      VLOG(10) << "Capture Forward CUDA Graph";
+    } else {
+      VLOG(10) << "Run Forward CUDA Graph directly";
+      ExecuteCUDAGraph(ctx, {"X"}, {"Out", "DOut"},
+                       inner_graphs[graph_idx].get());
+    }
+#else
+    PADDLE_THROW(
+        phi::errors::InvalidArgument("The cuda_graph_capture_mode is only "
+                                     "valid when using NVIDIA GPU."));
+#endif
+  }
+
+ private:
+  framework::PEAndGraphPair ComputeImpl(const framework::ExecutionContext &ctx,
+                                        bool is_test,
+                                        bool use_cuda_graph) const {
     VLOG(2) << "RunProgramOpKernel Compute";
+    framework::PEAndGraphPair pe_and_graph;
     // Step 1. prepare inputs, outputs, attrs
     auto &input_vars = ctx.MultiInputVar("X");
     auto &param_vars = ctx.MultiInputVar("Params");
@@ -192,7 +266,6 @@ class RunProgramOpKernel : public framework::OpKernel<T> {
 
     auto start_op_index = ctx.Attr<int64_t>("start_op_index");
     auto end_op_index = ctx.Attr<int64_t>("end_op_index");
-    auto is_test = ctx.Attr<bool>("is_test");
     auto program_id = ctx.Attr<int64_t>("program_id");
 
     // NOTE(chenweihang): In order not to add new variable type, use vector
@@ -223,15 +296,29 @@ class RunProgramOpKernel : public framework::OpKernel<T> {
 
     if (end_op_index > start_op_index) {
       auto *program = global_block->Program();
-      auto cache_info = framework::GetExecutorInfoFromCache(
-          *program, ctx.GetPlace(), start_op_index, end_op_index,
-          /*is_grad=*/false, program_id, &scope);
-      auto &parallel_executor = cache_info.first;
+      bool is_new_created;
+      if (use_cuda_graph) {
+        pe_and_graph = framework::CreateFixOrderExecutorInfo(
+            *program, ctx.GetPlace(), start_op_index, end_op_index, &scope);
+        is_new_created = true;
+      } else {
+        auto cache_info = framework::GetExecutorInfoFromCache(
+            *program, ctx.GetPlace(), start_op_index, end_op_index,
+            /*is_grad=*/false, program_id, &scope);
+        pe_and_graph.first = cache_info.first;
+        is_new_created = cache_info.second;
+      }
+
+      auto &parallel_executor = pe_and_graph.first;
+      // all out_vars are skip_eager_var
+      std::vector<std::string> tmp_vars;
       auto &skip_eager_delete_vars =
-          framework::ExecutorInfoCache::Instance().SkipEagerDeleteVars(
-              program_id, false);
-      if (cache_info.second /*is_new_created*/) {
+          use_cuda_graph
+              ? tmp_vars
+              : framework::ExecutorInfoCache::Instance().SkipEagerDeleteVars(
+                    program_id, false);
+      if (is_new_created) {
         parallel_executor->SkipMemoryReuse(/*scope_idx=*/0, input_var_names);
         skip_eager_delete_vars.insert(skip_eager_delete_vars.end(),
                                       output_var_names.begin(),
@@ -263,6 +350,7 @@ class RunProgramOpKernel : public framework::OpKernel<T> {
 #ifdef PADDLE_WITH_MKLDNN
     if (FLAGS_use_mkldnn) platform::DontClearMKLDNNCache(ctx.GetPlace());
 #endif
+    return pe_and_graph;
   }
 };
 
@@ -270,14 +358,68 @@ template <typename DeviceContext, typename T>
 class RunProgramGradOpKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext &ctx) const override {
+    const auto &capture_mode =
+        ctx.Attr<std::string>("cuda_graph_capture_mode");
+    if (capture_mode.empty()) {
+      ComputeImpl(ctx, false);
+      return;
+    }
+
+#ifdef PADDLE_WITH_CUDA
+    auto mode = details::StringToCUDAGraphCaptureMode(capture_mode);
+    PADDLE_ENFORCE_EQ(
+        platform::is_gpu_place(ctx.GetPlace()), true,
+        phi::errors::InvalidArgument("The cuda_graph_capture_mode is only "
+                                     "valid when using NVIDIA GPU."));
+    auto *graph_var =
+        const_cast<framework::Variable *>(ctx.InputVar("CUDAGraph"));
+    PADDLE_ENFORCE_NOT_NULL(
+        graph_var,
+        phi::errors::InvalidArgument("Output(CUDAGraph) must exist when "
+                                     "cuda_graph_capture_mode is valid."));
+    auto &inner_graphs = *(
+        graph_var
+            ->GetMutable<std::vector<std::unique_ptr<CUDAGraphWithInOuts>>>());
+    const size_t graph_idx = 2;
+    if (inner_graphs[graph_idx].get() == nullptr) {
+      framework::PEAndGraphPair pe_and_graph;
+      auto callable =
+          [this, &pe_and_graph](const framework::ExecutionContext &exe_ctx) {
+            pe_and_graph = ComputeImpl(exe_ctx, true);
+          };
+      int64_t pool_id = inner_graphs[0].get() != nullptr
+                            ? inner_graphs[0]->PoolID()
+                            : inner_graphs[1]->PoolID();
+      inner_graphs[graph_idx] =
+          CaptureCUDAGraph(callable, ctx, {framework::GradVarName("Out")},
+                           {framework::GradVarName("X")}, mode, pool_id);
+      VLOG(10) << "Capture Backward CUDA Graph";
+    } else {
+      ExecuteCUDAGraph(ctx, {framework::GradVarName("Out")},
+                       {framework::GradVarName("X")},
+                       inner_graphs[graph_idx].get());
+      VLOG(10) << "Run Backward CUDA Graph directly";
+    }
+#else
+    PADDLE_THROW(
+        phi::errors::InvalidArgument("The cuda_graph_capture_mode is only "
+                                     "valid when using NVIDIA GPU."));
+#endif
+  }
+
+ private:
+  framework::PEAndGraphPair ComputeImpl(const framework::ExecutionContext &ctx,
+                                        bool use_cuda_graph) const {
     VLOG(2) << "RunProgramGradOpKernel Compute";
+    framework::PEAndGraphPair pe_and_graph;
     // Step 1. prepare inputs and outputs
     auto &output_grad_vars = ctx.MultiInputVar(framework::GradVarName("Out"));
     auto input_grad_vars = ctx.MultiOutputVar(framework::GradVarName("X"));
     auto param_grad_vars = ctx.MultiOutputVar(framework::GradVarName("Params"));
 
     // if all output vars are set to stop_gradient, grad op no need to executed
-    if (input_grad_vars.empty() && param_grad_vars.empty()) return;
+    if (input_grad_vars.empty() && param_grad_vars.empty()) {
+      return pe_and_graph;
+    }
 
     auto output_grad_var_names = ctx.InputNames(framework::GradVarName("Out"));
     // NOTE: after PR22939 [Add double grad] merged, the grad op maker's
@@ -321,15 +463,27 @@ class RunProgramGradOpKernel : public framework::OpKernel<T> {
     if (end_op_index > start_op_index) {
       // Step 2. prepare executor and scope
       auto *program = global_block->Program();
-      auto cache_info = framework::GetExecutorInfoFromCache(
-          *program, ctx.GetPlace(), start_op_index, end_op_index,
-          /*is_grad*/ true, program_id, &scope);
-      auto &parallel_executor = cache_info.first;
+      bool is_new_created;
+      if (use_cuda_graph) {
+        pe_and_graph = framework::CreateFixOrderExecutorInfo(
+            *program, ctx.GetPlace(), start_op_index, end_op_index, &scope);
+        is_new_created = true;
+      } else {
+        auto cache_info = framework::GetExecutorInfoFromCache(
+            *program, ctx.GetPlace(), start_op_index, end_op_index,
+            /*is_grad*/ true, program_id, &scope);
+        pe_and_graph.first = cache_info.first;
+        is_new_created = cache_info.second;
+      }
 
+      auto &parallel_executor = pe_and_graph.first;
+      std::vector<std::string> tmp_vars;
       auto &skip_eager_delete_vars =
-          framework::ExecutorInfoCache::Instance().SkipEagerDeleteVars(
-              program_id, true);
-      if (cache_info.second /*is_new_created*/) {
+          use_cuda_graph
+              ? tmp_vars
+              : framework::ExecutorInfoCache::Instance().SkipEagerDeleteVars(
+                    program_id, true);
+      if (is_new_created) {
         parallel_executor->SkipMemoryReuse(/*scope_idx=*/0,
                                            output_grad_var_names);
 
@@ -360,6 +514,7 @@ class RunProgramGradOpKernel : public framework::OpKernel<T> {
     global_inner_scope->DeleteScope(&scope);
     VLOG(2) << "The number of sub scopes after backward: "
             << global_inner_scope->kids().size();
+    return pe_and_graph;
   }
 };
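Reading the two kernels together, the `CUDAGraph` variable holds up to three captured graphs, and the indexing convention (inferred from `graph_idx` above; the constant names below are hypothetical, for illustration only) is:

```cpp
// Slot convention inferred from RunProgramOpKernel / RunProgramGradOpKernel:
constexpr size_t kForwardTestGraph = 0;   // forward, is_test == true
constexpr size_t kForwardTrainGraph = 1;  // forward, is_test == false
constexpr size_t kBackwardGraph = 2;      // backward
// The backward capture reuses the forward graph's PoolID(), so forward
// outputs allocated from that pool stay valid while the backward replays.
```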
diff --git a/paddle/fluid/platform/cuda_graph_with_memory_pool.cc b/paddle/fluid/platform/cuda_graph_with_memory_pool.cc
index 4804d3f6ed3016eb35b6688304e406375acf3615..c40a43dbfb876c0ee997580dd128bff19c0aac45 100644
--- a/paddle/fluid/platform/cuda_graph_with_memory_pool.cc
+++ b/paddle/fluid/platform/cuda_graph_with_memory_pool.cc
@@ -16,23 +16,33 @@
 #include "paddle/fluid/memory/allocation/allocator_facade.h"
 #include "paddle/fluid/platform/device_context.h"
 
+DECLARE_bool(use_stream_safe_cuda_allocator);
+
 namespace paddle {
 namespace platform {
 
 #ifdef PADDLE_WITH_CUDA
 void BeginCUDAGraphCapture(platform::CUDAPlace place,
-                           cudaStreamCaptureMode mode) {
+                           cudaStreamCaptureMode mode, int64_t pool_id) {
   auto *dev_ctx = platform::DeviceContextPool::Instance().GetByPlace(place);
   dev_ctx->cudnn_workspace_handle().ResetWorkspace();
 
   auto stream = dev_ctx->stream();
   CUDAGraph::BeginCapture(place, stream, mode);
-  auto id = CUDAGraph::CapturingID();
+
+  auto old_value = FLAGS_use_stream_safe_cuda_allocator;
+  if (old_value) {
+    FLAGS_use_stream_safe_cuda_allocator = false;
+  }
+  pool_id = CUDAGraph::SetMemoryPoolID(pool_id);
   memory::allocation::AllocatorFacade::Instance().PrepareMemoryPoolForCUDAGraph(
-      id);
-  AddResetCallbackIfCapturingCUDAGraph([id] {
+      pool_id);
+  if (old_value) {
+    FLAGS_use_stream_safe_cuda_allocator = true;
+  }
+  AddResetCallbackIfCapturingCUDAGraph([pool_id] {
     memory::allocation::AllocatorFacade::Instance().RemoveMemoryPoolOfCUDAGraph(
-        id);
+        pool_id);
   });
 }
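A minimal usage sketch of the capture API above (assumes a CUDA build and a valid device; error handling omitted). Passing no pool id means `kInvalidPoolID`, which makes `SetMemoryPoolID` allocate a fresh pool; passing an existing `PoolID()` shares memory with a previously captured graph:

```cpp
#include "paddle/fluid/platform/cuda_graph_with_memory_pool.h"

void CaptureAndReplayOnce(paddle::platform::CUDAPlace place) {
  paddle::platform::BeginCUDAGraphCapture(place,
                                          cudaStreamCaptureModeThreadLocal);
  // ... launch the kernels to be recorded on the capturing stream ...
  auto graph = paddle::platform::EndCUDAGraphCapture();
  graph->Replay();  // re-launches the recorded kernels
}
```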
diff --git a/paddle/fluid/platform/cuda_graph_with_memory_pool.h b/paddle/fluid/platform/cuda_graph_with_memory_pool.h
index 7a9e1a3a1419ca62b794e53df0bd34b45dae8b9e..81b68a5c6786eaa45f930fe1c12e7a1c13f3dcd3 100644
--- a/paddle/fluid/platform/cuda_graph_with_memory_pool.h
+++ b/paddle/fluid/platform/cuda_graph_with_memory_pool.h
@@ -23,10 +23,53 @@
 namespace paddle {
 namespace platform {
 
+#ifdef PADDLE_WITH_CUDA
+#define PD_RECORD_CUDA_GRAPH_RANDOM_KERNEL(__cond, __kernel_func, __grid,     \
+                                           __block, __sm_size, __stream,     \
+                                           __seed_inc, __seed_expr,          \
+                                           __offset_expr, ...)               \
+  do {                                                                       \
+    if (::paddle::platform::CUDAGraph::IsThisThreadCapturing() && (__cond)) { \
+      using __Helper =                                                       \
+          ::paddle::platform::IsSameKernelHelper<decltype(&__kernel_func),   \
+                                                 &__kernel_func>;            \
+      auto *dev_ctx =                                                        \
+          ::paddle::platform::DeviceContextPool::Instance().GetByPlace(      \
+              ::paddle::platform::CUDAGraph::CapturingPlace());              \
+      auto __set_seed_func = [=](                                            \
+          ::paddle::platform::CUDAKernelParams *__params,                    \
+          bool __check_only) -> bool {                                       \
+        if (__check_only) {                                                  \
+          return __params->func() == &__kernel_func &&                       \
+                 __Helper::Compare(*__params, __VA_ARGS__);                  \
+        }                                                                    \
+        auto &KERNEL_PARAMS = *__params;                                     \
+        uint64_t __seed, __offset;                                           \
+        ::paddle::operators::GetSeedDataAndIncrement(                        \
+            *dev_ctx, nullptr, false, 0, __seed_inc, &__seed, &__offset);    \
+        __seed_expr = static_cast<uint64_t>(__seed);                         \
+        __offset_expr = static_cast<uint64_t>(__offset);                     \
+        return true;                                                         \
+      };                                                                     \
+      ::paddle::platform::CUDAGraph::RecordRandomKernelInfo(__set_seed_func); \
+    }                                                                        \
+    __kernel_func<<<__grid, __block, __sm_size, __stream>>>(__VA_ARGS__);    \
+  } while (0)
+#else
+#define PD_RECORD_CUDA_GRAPH_RANDOM_KERNEL(__cond, __kernel_func, __grid,    \
+                                           __block, __sm_size, __stream,     \
+                                           __seed_inc, __seed_expr,          \
+                                           __offset_expr, ...)               \
+  do {                                                                       \
+    __kernel_func<<<__grid, __block, __sm_size, __stream>>>(__VA_ARGS__);    \
+  } while (0)
+#endif
+
 // NOTE: These APIs are not thread-safe.
 #ifdef PADDLE_WITH_CUDA
 void BeginCUDAGraphCapture(platform::CUDAPlace place,
-                           cudaStreamCaptureMode mode);
+                           cudaStreamCaptureMode mode,
+                           int64_t pool_id = CUDAGraph::kInvalidPoolID);
 std::unique_ptr<CUDAGraph> EndCUDAGraphCapture();
 #endif
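The `__set_seed_func` lambda built by the macro is invoked in two phases by cuda_graph.cc (next file): once with `__check_only == true` at capture end, to identify which kernel node in the captured graph belongs to this launch, and once per replay with `false`, to write a fresh seed/offset into the recorded parameter block. A schematic of that contract, with stand-in types:

```cpp
#include <cstdint>
#include <functional>

// Stand-in for platform::CUDAKernelParams: positional access to the recorded
// kernel arguments.
struct KernelParamsView {
  const void *func;
  uint64_t *seed;
  uint64_t *offset;
};

// Phase 1 (check_only == true): "is this node my kernel with my arguments?"
// Phase 2 (check_only == false): overwrite seed/offset before the next replay.
using SetSeedFunc = std::function<bool(KernelParamsView *, bool)>;

bool ExampleSetSeed(KernelParamsView *params, bool check_only) {
  if (check_only) return params->func != nullptr;  // match test (schematic)
  *params->seed = 123;    // freshly generated seed
  *params->offset = 456;  // freshly generated offset
  return true;
}
```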
#include "paddle/fluid/platform/device/gpu/cuda/cuda_graph.h" +#include +#include +#include namespace paddle { namespace platform { @@ -20,6 +23,69 @@ namespace platform { std::unique_ptr CUDAGraph::capturing_graph_{nullptr}; paddle::optional CUDAGraph::capturing_thread_id_{paddle::none}; +static std::vector ToposortCUDAGraph(cudaGraph_t graph) { + size_t num_nodes; + PADDLE_ENFORCE_GPU_SUCCESS(cudaGraphGetNodes(graph, nullptr, &num_nodes)); + std::vector nodes(num_nodes); + PADDLE_ENFORCE_GPU_SUCCESS( + cudaGraphGetNodes(graph, nodes.data(), &num_nodes)); + + size_t num_edges; + PADDLE_ENFORCE_GPU_SUCCESS( + cudaGraphGetEdges(graph, nullptr, nullptr, &num_edges)); + std::vector from(num_edges), to(num_edges); + PADDLE_ENFORCE_GPU_SUCCESS( + cudaGraphGetEdges(graph, from.data(), to.data(), &num_edges)); + + std::unordered_map> + in_edges, out_edges; + for (auto node : nodes) { + in_edges[node]; + out_edges[node]; + } + + for (size_t i = 0; i < num_edges; ++i) { + in_edges[to[i]].insert(from[i]); + out_edges[from[i]].insert(to[i]); + } + + std::queue q; + for (const auto &pair : in_edges) { + if (pair.second.empty()) { + q.push(pair.first); + } + } + + nodes.clear(); + while (!q.empty()) { + auto cur = q.front(); + q.pop(); + nodes.push_back(cur); + + for (auto out_node : out_edges.at(cur)) { + auto &in_nodes = in_edges.at(out_node); + in_nodes.erase(cur); + if (in_nodes.empty()) { + q.push(out_node); + } + } + } + PADDLE_ENFORCE_EQ( + nodes.size(), num_nodes, + phi::errors::InvalidArgument("Toposort error, this may be a bug.")); + return nodes; +} + +CUDAGraphID CUDAGraph::UniqueID() { + static std::atomic id; + return id.fetch_add(1); +} + +int64_t CUDAGraph::UniqueMemoryPoolID() { + static std::atomic id(CUDAGraph::kDefaultPoolID + 1); + return id.fetch_add(1); +} + void CUDAGraph::Reset() { if (is_reset_) return; #if CUDA_VERSION >= 10010 @@ -46,9 +112,16 @@ void CUDAGraph::Replay() { PADDLE_ENFORCE_EQ(is_reset_, false, errors::PermissionDenied( "Cannot replay the CUDA Graph after reset is called.")); - for (auto exec_graph : exec_graphs_) { - PADDLE_ENFORCE_GPU_SUCCESS(cudaGraphLaunch(exec_graph, stream_)); + size_t n = exec_graphs_.size(); + for (size_t i = 0; i < n; ++i) { + if (!is_first_run_) { + for (auto &hook : pre_hooks_[i]) { + hook(exec_graphs_[i]); + } + } + PADDLE_ENFORCE_GPU_SUCCESS(cudaGraphLaunch(exec_graphs_[i], stream_)); } + is_first_run_ = false; #endif } @@ -72,7 +145,8 @@ void CUDAGraph::BeginSegmentCapture() { platform::errors::PermissionDenied( "CUDA Graph should not be invalidated.")); VLOG(10) << "Begin to capture CUDA Graph with ID " << capturing_graph_->id_ - << ", segment id " << capturing_graph_->graphs_.size(); + << ", segment id " << capturing_graph_->graphs_.size() + << ", memory pool id " << capturing_graph_->pool_id_; #endif } @@ -112,15 +186,57 @@ void CUDAGraph::EndSegmentCapture() { if (num_nodes == 0) { PADDLE_ENFORCE_GPU_SUCCESS(cudaGraphDestroy(graph)); VLOG(10) << "Skip empty CUDA Graph with ID " << capturing_graph_->id_ - << ", segment id " << capturing_graph_->graphs_.size(); + << ", segment id " << capturing_graph_->graphs_.size() + << ", memory pool id " << capturing_graph_->pool_id_; return; } + auto sorted_nodes = ToposortCUDAGraph(graph); + capturing_graph_->pre_hooks_.emplace_back(); + std::unordered_set visited; + VLOG(10) << "SetSeedFunc number : " + << capturing_graph_->set_seed_funcs_.size(); + for (const auto &set_seed_func : capturing_graph_->set_seed_funcs_) { + bool found = false; + for (auto node : sorted_nodes) { + if 
(visited.count(node) > 0) continue; + cudaGraphNodeType type; + PADDLE_ENFORCE_GPU_SUCCESS(cudaGraphNodeGetType(node, &type)); + if (type == cudaGraphNodeTypeKernel) { + cudaKernelNodeParams params; + auto err = cudaGraphKernelNodeGetParams(node, ¶ms); + if (err == cudaErrorInvalidDeviceFunction) { + continue; + } else { + PADDLE_ENFORCE_GPU_SUCCESS(err); + } + CUDAKernelParams kernel_params(¶ms); + if (set_seed_func(&kernel_params, true)) { + capturing_graph_->pre_hooks_.back().push_back( + [set_seed_func, node, params](cudaGraphExec_t exec_graph) { + CUDAKernelParams kernel_params(¶ms); + set_seed_func(&kernel_params, false); + PADDLE_ENFORCE_GPU_SUCCESS(cudaGraphExecKernelNodeSetParams( + exec_graph, node, ¶ms)); + }); + visited.insert(node); + found = true; + break; + } + } + } + PADDLE_ENFORCE_EQ(found, true, + phi::errors::InvalidArgument( + "Cannot find the corresponding random CUDA kernel.")); + } + capturing_graph_->set_seed_funcs_.clear(); + cudaGraphExec_t exec_graph; PADDLE_ENFORCE_GPU_SUCCESS( cudaGraphInstantiate(&exec_graph, graph, nullptr, nullptr, 0)); VLOG(10) << "End to capture CUDA Graph with ID " << capturing_graph_->id_ - << ", segment id " << capturing_graph_->graphs_.size(); + << ", segment id " << capturing_graph_->graphs_.size() + << ", memory pool id " << capturing_graph_->pool_id_; capturing_graph_->graphs_.emplace_back(graph); capturing_graph_->exec_graphs_.emplace_back(exec_graph); #endif diff --git a/paddle/fluid/platform/device/gpu/cuda/cuda_graph.h b/paddle/fluid/platform/device/gpu/cuda/cuda_graph.h index ca1e7abb375cb65110fec3b73b6609519052a6bb..8f84f26adbdbccbecb72139b2e1be4d091037e54 100644 --- a/paddle/fluid/platform/device/gpu/cuda/cuda_graph.h +++ b/paddle/fluid/platform/device/gpu/cuda/cuda_graph.h @@ -32,6 +32,70 @@ namespace paddle { namespace platform { +template +static bool IsBitwiseEqual(const T &x, const T &y) { + return std::memcmp(&x, &y, sizeof(T)) == 0; +} + +class CUDAKernelParams { + public: + explicit CUDAKernelParams(const cudaKernelNodeParams *params) + : params_(params) {} + + const void *func() const { return params_->func; } + + template + T &As(size_t idx) const { + return *reinterpret_cast(params_->kernelParams[idx]); + } + + private: + const cudaKernelNodeParams *params_; +}; + +template +struct IsSameKernelHelper; + +template +struct IsSameKernelHelper { + private: + template + struct Impl { + static bool Compare(const CUDAKernelParams ¶ms, const TupleT &args) { + using CompareT = typename std::tuple_element::type; + if (!IsBitwiseEqual(params.As(IDX), + std::get(args))) { + return false; + } + + constexpr auto NewIsEnd = + (IDX + 1 == sizeof(std::tuple_size::value)); + return Impl::Compare(params, args); + } + }; + + template + struct Impl { + static bool Compare(const CUDAKernelParams ¶ms, const TupleT &args) { + return true; + } + }; + + public: + using FuncArgsTuple = decltype(std::make_tuple(std::declval()...)); + + template + static bool Compare(const CUDAKernelParams ¶ms, Args... 
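The matching loop in `EndSegmentCapture` pairs each recorded `set_seed_func` with the first unvisited kernel node it accepts, scanning nodes in topological order so that launch order and node order agree. A simplified, runnable model of that pairing logic:

```cpp
#include <functional>
#include <unordered_set>
#include <vector>

using Node = int;  // stand-in for cudaGraphNode_t

bool MatchAll(const std::vector<std::function<bool(Node)>> &accepts,
              const std::vector<Node> &sorted_nodes) {
  std::unordered_set<Node> visited;
  for (const auto &accept : accepts) {     // one per recorded random kernel
    bool found = false;
    for (Node n : sorted_nodes) {
      if (visited.count(n)) continue;      // each node is claimed at most once
      if (accept(n)) {                     // set_seed_func(..., /*check=*/true)
        visited.insert(n);
        found = true;
        break;
      }
    }
    if (!found) return false;              // mirrors the PADDLE_ENFORCE_EQ
  }
  return true;
}
```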
diff --git a/paddle/fluid/platform/device/gpu/cuda/cuda_graph.h b/paddle/fluid/platform/device/gpu/cuda/cuda_graph.h
index ca1e7abb375cb65110fec3b73b6609519052a6bb..8f84f26adbdbccbecb72139b2e1be4d091037e54 100644
--- a/paddle/fluid/platform/device/gpu/cuda/cuda_graph.h
+++ b/paddle/fluid/platform/device/gpu/cuda/cuda_graph.h
@@ -32,6 +32,70 @@ namespace paddle {
 namespace platform {
 
+template <typename T>
+static bool IsBitwiseEqual(const T &x, const T &y) {
+  return std::memcmp(&x, &y, sizeof(T)) == 0;
+}
+
+class CUDAKernelParams {
+ public:
+  explicit CUDAKernelParams(const cudaKernelNodeParams *params)
+      : params_(params) {}
+
+  const void *func() const { return params_->func; }
+
+  template <typename T>
+  T &As(size_t idx) const {
+    return *reinterpret_cast<T *>(params_->kernelParams[idx]);
+  }
+
+ private:
+  const cudaKernelNodeParams *params_;
+};
+
+template <typename F, F f>
+struct IsSameKernelHelper;
+
+template <typename Return, typename... FuncArgs,
+          Return (*kernel_fn)(FuncArgs...)>
+struct IsSameKernelHelper<Return (*)(FuncArgs...), kernel_fn> {
+ private:
+  template <typename TupleT, size_t IDX, bool IsEnd>
+  struct Impl {
+    static bool Compare(const CUDAKernelParams &params, const TupleT &args) {
+      using CompareT = typename std::tuple_element<IDX, FuncArgsTuple>::type;
+      if (!IsBitwiseEqual<CompareT>(params.As<CompareT>(IDX),
+                                    std::get<IDX>(args))) {
+        return false;
+      }
+
+      constexpr auto NewIsEnd = (IDX + 1 == std::tuple_size<TupleT>::value);
+      return Impl<TupleT, IDX + 1, NewIsEnd>::Compare(params, args);
+    }
+  };
+
+  template <typename TupleT, size_t IDX>
+  struct Impl<TupleT, IDX, /*IsEnd=*/true> {
+    static bool Compare(const CUDAKernelParams &params, const TupleT &args) {
+      return true;
+    }
+  };
+
+ public:
+  using FuncArgsTuple = decltype(std::make_tuple(std::declval<FuncArgs>()...));
+
+  template <typename... Args>
+  static bool Compare(const CUDAKernelParams &params, Args... args) {
+    constexpr auto kNumArgs = sizeof...(FuncArgs);
+    static_assert(kNumArgs == sizeof...(Args), "Argument number not match");
+
+    auto args_tuple = std::make_tuple(args...);
+    using TupleT = typename std::decay<decltype(args_tuple)>::type;
+    return Impl<TupleT, 0, kNumArgs == 0>::Compare(params, args_tuple);
+  }
+};
+
 #if CUDA_VERSION >= 10010
 static void ThrowErrorIfNotSupportCUDAGraph() {}
 #else
@@ -61,10 +125,35 @@ class CUDAGraph {
   }
 
  public:
+  static constexpr int64_t kDefaultPoolID = 0;
+  static constexpr int64_t kInvalidPoolID = -1;
+
   ~CUDAGraph() { Reset(); }
 
   CUDAGraphID ID() const { return id_; }
 
+  static int64_t SetMemoryPoolID(int64_t pool_id) {
+    auto &pool_id_ = capturing_graph_->pool_id_;
+    PADDLE_ENFORCE_EQ(
+        pool_id_, kInvalidPoolID,
+        phi::errors::InvalidArgument("Cannot reset memory pool id twice, the "
+                                     "former memory pool id is %d.",
+                                     pool_id_));
+    if (pool_id <= kInvalidPoolID) {
+      pool_id_ = UniqueMemoryPoolID();
+    } else {
+      PADDLE_ENFORCE_GE(
+          pool_id, kDefaultPoolID,
+          phi::errors::InvalidArgument("Invalid memory pool id %d.", pool_id));
+      pool_id_ = pool_id;
+    }
+    return pool_id_;
+  }
+
+  int64_t PoolID() const { return pool_id_; }
+
+  static int64_t CapturingPoolID() { return capturing_graph_->pool_id_; }
+
   void Replay();
 
   void Reset();
@@ -120,12 +209,17 @@ class CUDAGraph {
     }
   }
 
- private:
-  static CUDAGraphID UniqueID() {
-    static std::atomic<CUDAGraphID> id;
-    return id.fetch_add(1);
+  using SetSeedFunc = std::function<bool(CUDAKernelParams *, bool)>;
+  static void RecordRandomKernelInfo(SetSeedFunc set_seed_func) {
+    std::lock_guard<std::mutex> guard(capturing_graph_->func_mtx_);
+    capturing_graph_->set_seed_funcs_.emplace_back(std::move(set_seed_func));
   }
 
+  static int64_t UniqueMemoryPoolID();
+
+ private:
+  static CUDAGraphID UniqueID();
+
  private:
 #if CUDA_VERSION >= 10010
   std::vector<cudaGraph_t> graphs_;
@@ -135,10 +229,17 @@ class CUDAGraph {
   cudaStream_t stream_{nullptr};
   platform::CUDAPlace place_;
   CUDAGraphID id_;
+  int64_t pool_id_{kInvalidPoolID};
   std::vector<std::function<void()>> callbacks_;
   bool is_reset_{false};
   std::mutex mtx_;
 
+  std::vector<SetSeedFunc> set_seed_funcs_;
+  std::vector<std::vector<std::function<void(cudaGraphExec_t)>>> pre_hooks_;
+  std::mutex func_mtx_;
+
+  bool is_first_run_{true};
+
   static paddle::optional<std::thread::id> capturing_thread_id_;
   static std::unique_ptr<CUDAGraph> capturing_graph_;
 };
diff --git a/paddle/fluid/pybind/eager_custom_python_api.h b/paddle/fluid/pybind/eager_custom_python_api.h
index 99ec4212918dee49930e169c6edb27f8c6d9b10d..a3e996dbcbf6472ecbeb22593f563609275d50fe 100644
--- a/paddle/fluid/pybind/eager_custom_python_api.h
+++ b/paddle/fluid/pybind/eager_custom_python_api.h
@@ -27,7 +27,8 @@ static PyObject *eager_api_run_program(PyObject *self, PyObject *args,
       GetScopePtrListFromArgs("run_program", "OutScope", args, 3, false);
   auto DOut = GetTensorPtrListFromArgs("run_program", "DOut", args, 4, true);
   framework::AttributeMap attrs;
-  ConstructAttrMapFromPyArgs("run_program", args, 5, PyTuple_GET_SIZE(args),
+  // TODO(zengjinle): support CUDA Graph on eager mode
+  ConstructAttrMapFromPyArgs("run_program", args, 6, PyTuple_GET_SIZE(args),
                              attrs);
 
   tstate = PyEval_SaveThread();
diff --git a/paddle/fluid/pybind/op_function_common.cc b/paddle/fluid/pybind/op_function_common.cc
index 0e9c08cff28599b631e1ceae166136b047cc5b14..8b9b98eba126e0367bb73c7b19a75bdbfa06d99f 100644
--- a/paddle/fluid/pybind/op_function_common.cc
+++ b/paddle/fluid/pybind/op_function_common.cc
@@ -640,10 +640,11 @@ void CastPyArg2AttrBlock(PyObject* obj,
 void ConstructAttrMapFromPyArgs(
     const std::string& op_type, PyObject* args, ssize_t attr_start,
     ssize_t attr_end, paddle::framework::AttributeMap& attrs) {  // NOLINT
-  PADDLE_ENFORCE_EQ(
-      (attr_end - attr_start) % 2, 0,
-      platform::errors::InvalidArgument(
-          "The number of arguments for attributes should be even."));
+  PADDLE_ENFORCE_EQ((attr_end - attr_start) % 2, 0,
+                    platform::errors::InvalidArgument(
+                        "The number of arguments for attributes should be even "
+                        "but attr_start = %d, attr_end = %d.",
+                        attr_start, attr_end));
 
   auto attr_type_map = &(OpAttrTypeMap::Instance().Map()[op_type]);
diff --git a/paddle/fluid/pybind/op_function_generator.h b/paddle/fluid/pybind/op_function_generator.h
index 972e8aafab758bdc04f5dd527c2b7ce3f1585e52..a6fd06f5d7059da9ed004124613e0060ccdfad1c 100644
--- a/paddle/fluid/pybind/op_function_generator.h
+++ b/paddle/fluid/pybind/op_function_generator.h
@@ -182,7 +182,7 @@ std::map<std::string, std::set<std::string>> op_outs_map = {
     {"merged_momentum", {"ParamOut", "VelocityOut", "MasterParamOut"}},
     {"sparse_momentum", {"ParamOut", "VelocityOut", "MasterParamOut"}},
     {"rnn", {"DropoutState", "Reserve", "Out", "State"}},
-    {"run_program", {"DOut"}},
+    {"run_program", {"DOut", "CUDAGraph"}},
     {"adam",
      {"ParamOut", "Moment1Out", "Moment2Out", "Beta1PowOut", "Beta2PowOut",
       "MasterParamOut"}},
@@ -267,7 +267,7 @@ std::map<std::string, std::set<std::string>> op_passing_outs_map = {
     {"moving_average_abs_max_scale",
      {"Out", "OutScale", "OutAccum", "OutState"}},
     {"rnn", {"DropoutState"}},
-    {"run_program", {"Out", "DOut", "OutScope"}},
+    {"run_program", {"Out", "DOut", "OutScope", "CUDAGraph"}},
    {"clear_float_status", {"FloatStatusOut"}},
     {"get_float_status", {"FloatStatusOut"}},
     {"assign", {"Out"}},
diff --git a/paddle/fluid/pybind/pybind.cc b/paddle/fluid/pybind/pybind.cc
index 0e1271c1fe07f5791e106217c1b7b5a659fe019b..d1c2b28dc80cf65a59f9869fbde434b4feb61095 100644
--- a/paddle/fluid/pybind/pybind.cc
+++ b/paddle/fluid/pybind/pybind.cc
@@ -604,6 +604,8 @@ PYBIND11_MODULE(core_noavx, m) {
                  place, static_cast<cudaStreamCaptureMode>(mode));
            })
       .def_static("end_capture", &platform::EndCUDAGraphCapture)
+      .def_static("gen_new_memory_pool_id",
+                  &platform::CUDAGraph::UniqueMemoryPoolID)
       .def("replay", &platform::CUDAGraph::Replay)
       .def("reset", &platform::CUDAGraph::Reset)
       .def("print_to_dot_files", &platform::CUDAGraph::PrintToDotFiles);
diff --git a/python/paddle/device/cuda/graphs.py b/python/paddle/device/cuda/graphs.py
index 29e1b2694a699272784e1107807373e922b465b1..e7987cf447ff5ce6cd9b4dadb19ef5c3a505b2de 100644
--- a/python/paddle/device/cuda/graphs.py
+++ b/python/paddle/device/cuda/graphs.py
@@ -17,15 +17,23 @@ from paddle.fluid.core import is_compiled_with_cuda, is_compiled_with_rocm, CUDAPlace
 
 if is_compiled_with_cuda() and not is_compiled_with_rocm():
     from paddle.fluid.core import CUDAGraph as CoreCUDAGraph
+
+    def is_cuda_graph_supported():
+        return True
 else:
     CoreCUDAGraph = None
 
+    def is_cuda_graph_supported():
+        return False
+
+
+ALL_MODES = ["global", "thread_local", "relaxed"]
+
 
 class CUDAGraph:
     def __init__(self, place=None, mode="thread_local"):
         assert CoreCUDAGraph is not None, "CUDA Graph is only supported on PaddlePaddle compiled with NVIDIA GPU."
 
-        ALL_MODES = ["global", "thread_local", "relaxed"]
         self._graph = None
         if place is None:
             device_id = int(os.environ.get('FLAGS_selected_gpus', 0))
@@ -55,3 +63,25 @@ class CUDAGraph:
         if flags is None:
             flags = 2047  # only all information. It can be any integer inside [1, 2048)
         self._graph.print_to_dot_files(dirname, flags)
+
+
+def wrap_cuda_graph(function, mode="thread_local", memory_pool="default"):
+    assert mode in ALL_MODES
+    from paddle.jit import to_static
+    from paddle.nn import Layer
+    new_function = to_static(function)
+    if isinstance(function, Layer):
+        mock_func = new_function.forward
+    else:
+        mock_func = new_function
+    mock_func._cuda_graph_capture_mode = mode
+    if memory_pool == "default":
+        mock_func._cuda_graph_pool_id = 0
+    elif memory_pool == "new":
+        mock_func._cuda_graph_pool_id = CoreCUDAGraph.gen_new_memory_pool_id()
+    else:
+        if isinstance(memory_pool, Layer):
+            mock_func._cuda_graph_pool_id = memory_pool.forward._cuda_graph_pool_id
+        else:
+            mock_func._cuda_graph_pool_id = memory_pool._cuda_graph_pool_id
+    return new_function
diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/partial_program.py b/python/paddle/fluid/dygraph/dygraph_to_static/partial_program.py
index 90f960798ef2ccdcd3c94f1d8fc3a10ef2972166..64652dd8e35899e91324c2e7679b562766a0256b 100644
--- a/python/paddle/fluid/dygraph/dygraph_to_static/partial_program.py
+++ b/python/paddle/fluid/dygraph/dygraph_to_static/partial_program.py
@@ -148,6 +148,9 @@ class PartialProgramLayer:
         self._origin_main_program = self._verify_program(main_program)
         self._tmp_scope_vec = self._create_scope_vec()
+        self._cuda_graph_vec = self._create_cuda_graph_vec()
+        self._cuda_graph_capture_mode = ""
+        self._cuda_graph_pool_id = 0
         # Set default mode to train
         self.training = True
 
@@ -339,9 +342,15 @@ class PartialProgramLayer:
 
     def __call__(self, inputs):
         in_vars, out_vars = self._prepare(inputs)
-        attrs = ('global_block', self.program.desc.block(0), 'start_op_index',
-                 0, 'end_op_index', self._get_end_op_index(), 'is_test',
-                 not self.training, 'program_id', self.program_id)
+        attrs = [
+            'global_block', self.program.desc.block(0), 'start_op_index', 0,
+            'end_op_index', self._get_end_op_index(), 'is_test',
+            not self.training, 'program_id', self.program_id
+        ]
+        if self._cuda_graph_capture_mode:
+            attrs.extend(
+                ('cuda_graph_capture_mode', self._cuda_graph_capture_mode,
+                 'cuda_graph_pool_id', self._cuda_graph_pool_id))
 
         self._cast_fp16_if_pure_fp16(in_vars)
 
@@ -349,7 +358,7 @@ class PartialProgramLayer:
             self._valid_vars(in_vars),
             self._valid_vars(self._params),
             self._valid_vars(out_vars), self._tmp_scope_vec, self._double_grads,
-            *attrs)
+            self._cuda_graph_vec, *attrs)
         self.drop_scope_if_no_grad()
         restored_nest_out = self._restore_out(out_vars)
         return self._remove_no_value(restored_nest_out)
@@ -471,6 +480,12 @@ class PartialProgramLayer:
         tmp_scope_vec = [inner_scope]
         return tmp_scope_vec
 
+    def _create_cuda_graph_vec(self):
+        var = core.VarBase(core.VarDesc.VarType.FP32, [], "cuda_graph",
+                           core.VarDesc.VarType.RAW, True)
+        var.stop_gradient = True
+        return var
+
     def _restore_out(self, out_vars):
         """
         Restores same nested outputs by only replacing the Variable with VarBase.
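`_create_cuda_graph_vec` above builds a RAW-typed `VarBase` whose C++ payload is the graph vector registered in var_type_traits.h earlier in this patch; the run_program kernels materialize it on first use via `GetMutable`. A sketch of the C++ view, assuming the registration above (Paddle headers as named in this patch):

```cpp
#include <memory>
#include <vector>
#include "paddle/fluid/framework/variable.h"
#include "paddle/fluid/operators/cuda_graph_with_in_out.h"

void InspectCUDAGraphVar(paddle::framework::Variable *var) {
  // GetMutable default-constructs the vector on first access; the kernels
  // then resize it to three slots (test-fwd / train-fwd / backward).
  auto *graphs = var->GetMutable<
      std::vector<std::unique_ptr<paddle::operators::CUDAGraphWithInOuts>>>();
  graphs->resize(3);
}
```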
diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/program_translator.py b/python/paddle/fluid/dygraph/dygraph_to_static/program_translator.py
index eac2941c09778a02712008de1ff306d728cf81c6..207cff67a1bc4f6f00b6459f135009a8ba88f0f5 100644
--- a/python/paddle/fluid/dygraph/dygraph_to_static/program_translator.py
+++ b/python/paddle/fluid/dygraph/dygraph_to_static/program_translator.py
@@ -267,6 +267,8 @@ class StaticFunction(object):
         self._program_trans = ProgramTranslator()
         self._kwargs = kwargs
         self._training = True
+        self._cuda_graph_capture_mode = ""
+        self._cuda_graph_pool_id = 0
 
     def train(self):
         if isinstance(self._class_instance,
@@ -367,6 +369,9 @@ class StaticFunction(object):
         else:
             partial_program_layer.training = self._training
 
+        partial_program_layer._cuda_graph_capture_mode = self._cuda_graph_capture_mode
+        partial_program_layer._cuda_graph_pool_id = self._cuda_graph_pool_id
+
         # 4. return outputs.
         try:
             return partial_program_layer(args)
diff --git a/python/paddle/fluid/dygraph/io.py b/python/paddle/fluid/dygraph/io.py
index f10b65222021457d1e892430b69f42f32e0107bc..249c7b6a064258637c5daa5f006c13bd03e79839 100644
--- a/python/paddle/fluid/dygraph/io.py
+++ b/python/paddle/fluid/dygraph/io.py
@@ -874,7 +874,7 @@ def _run_dygraph(instance, input, program_holder):
         _valid_vars(input_vars),
         _valid_vars(persistable_vars),
         _valid_vars(output_vars), tmp_scope_vec,
-        _valid_vars(double_grad_vars), *attrs)
+        _valid_vars(double_grad_vars), None, *attrs)
 
     # NOTE: [ why need set param's gradient type here ]
     # if user set sparse gradient mode, the param's gradient
     # will be SelectedRows, not LoDTensor. But tracer will just
diff --git a/python/paddle/fluid/tests/unittests/test_cuda_graph_partial_graph.py b/python/paddle/fluid/tests/unittests/test_cuda_graph_partial_graph.py
new file mode 100644
index 0000000000000000000000000000000000000000..182a70af8a890fc1eccbb6477571bf213244b6d8
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/test_cuda_graph_partial_graph.py
@@ -0,0 +1,78 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import paddle
+import paddle.nn as nn
+import unittest
+import numpy as np
+from paddle.device.cuda.graphs import wrap_cuda_graph, is_cuda_graph_supported
+
+
+class SimpleModel(nn.Layer):
+    def __init__(self, in_size, out_size):
+        super(SimpleModel, self).__init__()
+        self.linear = nn.Linear(in_size, out_size)
+        self.dropout_1 = paddle.nn.Dropout(0.1)
+        self.relu = nn.ReLU()
+        self.dropout_2 = paddle.nn.Dropout(0.5)
+        self.gelu = nn.GELU()
+
+    def forward(self, x):
+        x = self.linear(x)
+        x = self.dropout_1(x)
+        x = self.relu(x)
+        x = self.dropout_2(x)
+        x = self.gelu(x)
+        return x
+
+
+class TestSimpleModel(unittest.TestCase):
+    def setUp(self):
+        paddle.set_flags({'FLAGS_eager_delete_tensor_gb': 0.0})
+
+    def run_base(self, func, use_cuda_graph, memory_pool="default", seed=10):
+        paddle.seed(seed)
+        is_layer = isinstance(func, paddle.nn.Layer)
+        if use_cuda_graph:
+            func = wrap_cuda_graph(func, memory_pool=memory_pool)
+
+        for _ in range(10):
+            x = paddle.randn([3, 10], dtype='float32')
+            x.stop_gradient = False
+            y = x * x + 100
+            loss = func(y).mean()
+            loss.backward()
+            if is_layer:
+                func.clear_gradients()
+
+        return func, x.grad.numpy()
+
+    def check(self, func):
+        if not is_cuda_graph_supported():
+            return
+
+        _, value1 = self.run_base(func, False)
+        layer, value2 = self.run_base(func, True, "default")
+        _, value3 = self.run_base(func, True, "new")
+        _, value4 = self.run_base(func, True, layer)
+        self.assertTrue(np.array_equal(value1, value2))
+        self.assertTrue(np.array_equal(value1, value3))
+        self.assertTrue(np.array_equal(value1, value4))
+
+    def test_layer(self):
+        self.check(SimpleModel(10, 20))
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/python/paddle/fluid/tests/unittests/test_eager_run_program.py b/python/paddle/fluid/tests/unittests/test_eager_run_program.py
index 0253f9a21c6adb2391bf501d2651e385fa03ed1c..620f72ccb3073ca4a7995776d99725ab7caab328 100644
--- a/python/paddle/fluid/tests/unittests/test_eager_run_program.py
+++ b/python/paddle/fluid/tests/unittests/test_eager_run_program.py
@@ -104,7 +104,7 @@ class TestRunProgram(unittest.TestCase):
                  'is_test', False, 'program_id', _hash_with_id(program))
 
         _C_ops.run_program([x_t, y_t], [fake_var], [out_t], [scope],
-                           [fake_var], *attrs)
+                           [fake_var], None, *attrs)
 
         loss = paddle.mean(out_t)
         loss.backward()
diff --git a/python/paddle/fluid/tests/unittests/test_run_program_op.py b/python/paddle/fluid/tests/unittests/test_run_program_op.py
index 68f24bf25700848222843cb06852583d0f4db377..fdb931e25314ed55241ddeeeff8e38e1dba7f519 100644
--- a/python/paddle/fluid/tests/unittests/test_run_program_op.py
+++ b/python/paddle/fluid/tests/unittests/test_run_program_op.py
@@ -188,7 +188,7 @@ class RunProgramOpTest(unittest.TestCase):
                 outputs = self.prepare_dygraph_output()
 
                 _C_ops.run_program(inputs['X'], inputs['Params'], outputs['Out'],
-                                   outputs['OutScope'], outputs['DOut'],
+                                   outputs['OutScope'], outputs['DOut'], None,
                                    *self.attrs)
         return outputs['Out']
 
@@ -202,7 +202,7 @@ class RunProgramOpTest(unittest.TestCase):
                 outputs = self.prepare_dygraph_output()
 
                 _C_ops.run_program(inputs['X'], inputs['Params'], outputs['Out'],
-                                   outputs['OutScope'], outputs['DOut'],
+                                   outputs['OutScope'], outputs['DOut'], None,
                                    *self.attrs)
                 for param in input_param_list: