Unverified commit d05b940a, authored by S sneaxiy, committed by GitHub

Support CUDA Graph for partial graph in dygraph mode (#42786)

* support CUDAGraph for partial graph

* add ut

* fix ci

* fix ut again because of eager mode

* fix kunlun ci

* fix win ci
Parent 126248ac
......@@ -46,6 +46,12 @@ FastThreadedSSAGraphExecutor::FastThreadedSSAGraphExecutor(
VLOG(10)
<< "Change thread number to 1 because the toposort order is unique";
strategy_.num_threads_ = 1;
traced_ops_.clear();
for (auto *op_node : TopologySortOperations(*graph_)) {
if (op_node->IsWrappedBy<OpHandleBase>()) {
traced_ops_.emplace_back(&(op_node->Wrapper<OpHandleBase>()));
}
}
}
pool_.reset(new ::ThreadPool(strategy.num_threads_));
for (auto &op : ir::FilterByNodeWrapper<OpHandleBase>(*graph_)) {
......
......@@ -137,6 +137,31 @@ ExecutorInfoCache &ExecutorInfoCache::Instance() {
return g_exe_cache_info_map;
}
static PEAndGraphPair CreateExecutorInfo(
const ProgramDesc &program_desc, const platform::Place &place,
int64_t start_op_index, int64_t end_op_index, framework::Scope *scope,
const details::BuildStrategy &build_strategy) {
auto execution_strategy = details::GetExecutionStrategy(place);
auto graph = std::make_shared<framework::ir::Graph>(
program_desc, start_op_index, end_op_index);
auto parallel_executor = std::make_shared<framework::ParallelExecutor>(
place, scope, execution_strategy, build_strategy, graph.get());
parallel_executor->PrepareVariables(scope);
return std::make_pair(parallel_executor, graph);
}
PEAndGraphPair CreateFixOrderExecutorInfo(const ProgramDesc &program_desc,
const platform::Place &place,
int64_t start_op_index,
int64_t end_op_index,
framework::Scope *scope) {
details::BuildStrategy build_strategy;
build_strategy.fix_op_run_order_ = true;
auto pe_and_graph = CreateExecutorInfo(program_desc, place, start_op_index,
end_op_index, scope, build_strategy);
return pe_and_graph;
}
CacheInfo GetExecutorInfoFromCache(const ProgramDesc &program_desc,
const platform::Place &place,
int64_t start_op_index, int64_t end_op_index,
......@@ -153,21 +178,17 @@ CacheInfo GetExecutorInfoFromCache(const ProgramDesc &program_desc,
}
VLOG(1) << "create exe_info for " << program_id << " is_grad: " << is_grad;
auto execution_strategy = details::GetExecutionStrategy(place);
auto &build_strategy = cached_exe_info.GetBuildStrategy(program_id);
// 2. Construct Graph and ParallelExecutor.
auto graph = std::make_shared<framework::ir::Graph>(
program_desc, start_op_index, end_op_index);
auto parallel_executor = std::make_shared<framework::ParallelExecutor>(
place, scope, execution_strategy, build_strategy, graph.get());
parallel_executor->PrepareVariables(scope);
auto pe_and_graph = CreateExecutorInfo(program_desc, place, start_op_index,
end_op_index, scope, build_strategy);
// 3. Insert value into cached map.
auto &cached_value = cached_exe_info.GetMutable(program_id, is_grad);
cached_value.executor_ = parallel_executor;
cached_value.graph_ = std::move(graph);
return std::make_pair(parallel_executor, /*is_new_created=*/true);
cached_value.executor_ = pe_and_graph.first;
cached_value.graph_ = pe_and_graph.second;
return std::make_pair(pe_and_graph.first, /*is_new_created=*/true);
} else {
VLOG(1) << "get exe_info from cache by: " << program_id
<< " is_grad: " << is_grad;
......
......@@ -127,11 +127,20 @@ class ExecutorInfoCache {
using CacheInfo =
std::pair<std::shared_ptr<ParallelExecutor>, bool /*is_new_created*/>;
using PEAndGraphPair =
std::pair<std::shared_ptr<ParallelExecutor>, std::shared_ptr<ir::Graph>>;
CacheInfo GetExecutorInfoFromCache(const ProgramDesc& program_desc,
const platform::Place& place,
int64_t start_op_index, int64_t end_op_index,
bool is_grad, int64_t program_id,
framework::Scope* scope);
PEAndGraphPair CreateFixOrderExecutorInfo(const ProgramDesc& program_desc,
const platform::Place& place,
int64_t start_op_index,
int64_t end_op_index,
framework::Scope* scope);
} // namespace framework
} // namespace paddle
......@@ -41,6 +41,8 @@
#include "paddle/fluid/platform/device/xpu/bkcl_helper.h"
#endif
#include "paddle/fluid/operators/cuda_graph_with_in_out.h"
namespace paddle {
namespace framework {
......
......@@ -87,6 +87,8 @@ namespace operators {
class CudnnRNNCache;
class CUDAGraphWithInOuts;
namespace reader {
class LoDTensorBlockingQueueHolder;
class OrderedMultiDeviceLoDTensorBlockingQueueHolder;
......@@ -189,7 +191,8 @@ using VarTypeRegistry = detail::VarTypeRegistryImpl<
#if defined(PADDLE_WITH_CNCL)
cnclCliqueId,
#endif
int, float, Vocab>;
std::vector<std::unique_ptr<operators::CUDAGraphWithInOuts>>, int, float,
Vocab>;
template <typename T>
struct VarTypeTrait {
static_assert(VarTypeRegistry::IsRegistered<T>(), "Must be registered type");
......
......@@ -123,6 +123,8 @@ class CUDAGraphAllocator
: underlying_allocator_(allocator) {}
public:
~CUDAGraphAllocator() { VLOG(10) << "CUDAGraphAllocator destructed"; }
static std::shared_ptr<Allocator> Create(
const std::shared_ptr<Allocator>& allocator) {
return std::shared_ptr<Allocator>(new CUDAGraphAllocator(allocator));
......@@ -973,7 +975,7 @@ AllocatorFacade& AllocatorFacade::Instance() {
AllocatorFacadePrivate* AllocatorFacade::GetPrivate() const {
#ifdef PADDLE_WITH_CUDA
if (UNLIKELY(IsCUDAGraphCapturing())) {
auto id = platform::CUDAGraph::CapturingID();
auto id = platform::CUDAGraph::CapturingPoolID();
auto iter = cuda_graph_map_.find(id);
PADDLE_ENFORCE_NE(
iter, cuda_graph_map_.end(),
......@@ -1116,7 +1118,7 @@ void AllocatorFacade::SetDefaultStream(const platform::CUDAPlace& place,
}
#ifdef PADDLE_WITH_CUDA
void AllocatorFacade::PrepareMemoryPoolForCUDAGraph(CUDAGraphID id) {
void AllocatorFacade::PrepareMemoryPoolForCUDAGraph(int64_t id) {
PADDLE_ENFORCE_EQ(GetAllocatorStrategy(), AllocatorStrategy::kAutoGrowth,
platform::errors::InvalidArgument(
"CUDA Graph is only supported when the "
......@@ -1124,23 +1126,32 @@ void AllocatorFacade::PrepareMemoryPoolForCUDAGraph(CUDAGraphID id) {
"FLAGS_allocator_strategy=\"%s\"",
FLAGS_allocator_strategy));
auto& allocator = cuda_graph_map_[id];
PADDLE_ENFORCE_EQ(
allocator.get(), nullptr,
platform::errors::InvalidArgument(
"The memory pool of the CUDA Graph with ID %d have been prepared.",
id));
allocator.reset(new AllocatorFacadePrivate(/*allow_free_idle_chunk=*/false));
VLOG(10) << "Prepare memory pool for CUDA Graph with ID " << id;
auto& ref_cnt = cuda_graph_ref_cnt_[id];
if (allocator.get() == nullptr) {
allocator.reset(
new AllocatorFacadePrivate(/*allow_free_idle_chunk=*/false));
VLOG(10) << "Create memory pool for CUDA Graph with memory ID " << id;
} else {
VLOG(10) << "Use created memory pool for CUDA Graph with memory ID " << id;
}
++ref_cnt;
}
void AllocatorFacade::RemoveMemoryPoolOfCUDAGraph(CUDAGraphID id) {
auto iter = cuda_graph_map_.find(id);
PADDLE_ENFORCE_NE(iter, cuda_graph_map_.end(),
void AllocatorFacade::RemoveMemoryPoolOfCUDAGraph(int64_t id) {
auto ref_cnt_iter = cuda_graph_ref_cnt_.find(id);
PADDLE_ENFORCE_NE(ref_cnt_iter, cuda_graph_ref_cnt_.end(),
platform::errors::InvalidArgument(
"Cannot find CUDA Graph with ID = %d", id));
cuda_graph_map_.erase(iter);
VLOG(10) << "Remove memory pool of CUDA Graph with ID " << id;
"Cannot find CUDA Graph with memory ID = %d", id));
auto& ref_cnt = ref_cnt_iter->second;
--ref_cnt;
if (ref_cnt == 0) {
cuda_graph_map_.erase(id);
cuda_graph_ref_cnt_.erase(ref_cnt_iter);
VLOG(10) << "Remove memory pool of CUDA Graph with memory ID " << id;
} else {
VLOG(10) << "Decrease memory pool ID " << id << " reference count to be "
<< ref_cnt;
}
}
#endif
#endif
......
......@@ -89,8 +89,8 @@ class AllocatorFacade {
#endif
#ifdef PADDLE_WITH_CUDA
void PrepareMemoryPoolForCUDAGraph(CUDAGraphID id);
void RemoveMemoryPoolOfCUDAGraph(CUDAGraphID id);
void PrepareMemoryPoolForCUDAGraph(int64_t id);
void RemoveMemoryPoolOfCUDAGraph(int64_t id);
#endif
// TODO(yy): Allocate a Copy-On-Write allocation?
......@@ -98,8 +98,9 @@ class AllocatorFacade {
AllocatorFacade();
AllocatorFacadePrivate* m_;
#ifdef PADDLE_WITH_CUDA
std::unordered_map<CUDAGraphID, std::unique_ptr<AllocatorFacadePrivate>>
std::unordered_map<int64_t, std::unique_ptr<AllocatorFacadePrivate>>
cuda_graph_map_;
std::unordered_map<int64_t, int64_t> cuda_graph_ref_cnt_;
#endif
};
......
......@@ -107,6 +107,7 @@ register_operators(EXCLUDES py_layer_op py_func_op warpctc_op dgc_op load_combin
recurrent_op save_combine_op sparse_attention_op sync_batch_norm_op spectral_op ${OP_MKL_DEPS} DEPS ${OP_HEADER_DEPS})
op_library(run_program_op SRCS run_program_op.cc run_program_op.cu.cc DEPS executor_cache ${OP_HEADER_DEPS})
target_link_libraries(run_program_op cuda_graph_with_memory_pool)
op_library(quantize_linear_op DEPS cast_kernel)
op_library(save_combine_op DEPS string_array)
op_library(load_combine_op DEPS string_array)
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/framework/tensor.h"
#ifdef PADDLE_WITH_CUDA
#include "paddle/fluid/platform/cuda_graph_with_memory_pool.h"
#endif
namespace paddle {
namespace operators {
#ifdef PADDLE_WITH_CUDA
class CUDAGraphWithInOuts {
public:
template <typename Callable>
CUDAGraphWithInOuts(Callable &&callable, platform::CUDAPlace place,
const std::vector<const framework::Tensor *> &in_ptrs,
cudaStreamCaptureMode mode, int64_t pool_id) {
in_indices_.resize(in_ptrs.size());
ins_.reserve(in_ptrs.size());
int64_t valid_in_idx = 0;
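// Record where each input lives in ins_; dispensable (nullptr) inputs are
// marked with -1 so that Run() can skip copying into them.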
for (size_t i = 0; i < in_ptrs.size(); ++i) {
if (in_ptrs[i] == nullptr) {
in_indices_[i] = -1;
} else {
in_indices_[i] = (valid_in_idx++);
ins_.push_back(*in_ptrs[i]);
}
}
platform::BeginCUDAGraphCapture(place, mode, pool_id);
auto out_ptrs = callable(in_ptrs);
graph_ = platform::EndCUDAGraphCapture();
graph_->Replay();
out_indices_.resize(out_ptrs.size());
outs_.reserve(out_ptrs.size());
int64_t valid_out_idx = 0;
for (size_t i = 0; i < out_ptrs.size(); ++i) {
if (out_ptrs[i] == nullptr) {
out_indices_[i] = -1;
} else {
out_indices_[i] = (valid_out_idx++);
outs_.push_back(*out_ptrs[i]);
}
}
}
void Run(const std::vector<const framework::Tensor *> &ins) {
PADDLE_ENFORCE_EQ(
ins.size(), in_indices_.size(),
phi::errors::InvalidArgument("The input number does not match."));
for (size_t i = 0; i < in_indices_.size(); ++i) {
if (in_indices_[i] >= 0) {
auto *dst = &ins_[in_indices_[i]];
framework::TensorCopy(*ins[i], dst->place(), dst);
}
}
graph_->Replay();
}
std::vector<framework::Tensor *> GetOutputs() {
std::vector<framework::Tensor *> outs(out_indices_.size());
for (size_t i = 0; i < out_indices_.size(); ++i) {
if (out_indices_[i] >= 0) {
outs[i] = &outs_[out_indices_[i]];
}
}
return outs;
}
int64_t PoolID() const { return graph_->PoolID(); }
private:
std::unique_ptr<platform::CUDAGraph> graph_;
std::vector<framework::Tensor> ins_;
std::vector<framework::Tensor> outs_;
std::vector<int64_t> in_indices_;
std::vector<int64_t> out_indices_;
};
template <typename Callable>
static std::unique_ptr<CUDAGraphWithInOuts> CaptureCUDAGraph(
Callable &&callable, const framework::ExecutionContext &ctx,
const std::vector<std::string> &input_names,
const std::vector<std::string> &output_names, cudaStreamCaptureMode mode,
int64_t pool_id) {
std::vector<const framework::Tensor *> inputs;
for (const auto &name : input_names) {
auto input_tensors = ctx.MultiInput<framework::Tensor>(name);
inputs.insert(inputs.end(), input_tensors.begin(), input_tensors.end());
}
auto func = [&](const std::vector<const framework::Tensor *> &inputs) {
callable(ctx);
std::vector<framework::Tensor *> outputs;
for (const auto &name : output_names) {
auto output_tensors = ctx.MultiOutput<framework::Tensor>(name);
outputs.insert(outputs.end(), output_tensors.begin(),
output_tensors.end());
}
return outputs;
};
return std::make_unique<CUDAGraphWithInOuts>(func, ctx.GetPlace(), inputs,
mode, pool_id);
}
static void ExecuteCUDAGraph(const framework::ExecutionContext &ctx,
const std::vector<std::string> &input_names,
const std::vector<std::string> &output_names,
CUDAGraphWithInOuts *graph) {
std::vector<const framework::Tensor *> inputs;
for (const auto &name : input_names) {
auto input_tensors = ctx.MultiInput<framework::Tensor>(name);
inputs.insert(inputs.end(), input_tensors.begin(), input_tensors.end());
}
graph->Run(inputs);
auto outputs = graph->GetOutputs();
size_t idx = 0;
for (const auto &name : output_names) {
auto output_tensors = ctx.MultiOutput<framework::Tensor>(name);
for (auto *out_t : output_tensors) {
if (outputs[idx] != nullptr) {
*out_t = *outputs[idx];
} else {
PADDLE_ENFORCE_EQ(
out_t, nullptr,
phi::errors::InvalidArgument(
"The %d-th output variable should be nullptr.", idx));
}
++idx;
}
}
}
#else
class CUDAGraphWithInOuts {};
#endif
} // namespace operators
} // namespace paddle
......@@ -34,6 +34,7 @@ limitations under the License. */
#include "paddle/fluid/operators/dropout_impl_util.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_impl.cu.h"
#include "paddle/fluid/platform/aligned_vector.h"
#include "paddle/fluid/platform/cuda_graph_with_memory_pool.h"
#include "paddle/phi/backends/gpu/gpu_launch_config.h"
#include "paddle/phi/kernels/funcs/distribution_helper.h"
#include "paddle/phi/kernels/funcs/functors.h"
......@@ -195,9 +196,11 @@ void DropoutFwGPUKernelDriver(const phi::GPUContext& dev_ctx, bool is_test,
size_t main_offset =
size / (block_size * kVecSize) * (block_size * kVecSize);
VectorizedRandomGenerator<T, uint8_t><<<grid_size, block_size, 0, stream>>>(
size, seed_data, dropout_prob, x_data, mask_data, y_data,
upscale_in_train, increment, main_offset);
PD_RECORD_CUDA_GRAPH_RANDOM_KERNEL(
!is_fix_seed, (VectorizedRandomGenerator<T, uint8_t>), grid_size,
block_size, 0, stream, offset, KERNEL_PARAMS.As<uint64_t>(1),
KERNEL_PARAMS.As<uint64_t>(7), size, seed_data, dropout_prob, x_data,
mask_data, y_data, upscale_in_train, increment, main_offset);
} else {
if (upscale_in_train) {
// TODO: can y share data with x directly?
......
......@@ -90,6 +90,8 @@ class RunProgramOpMaker : public framework::OpProtoAndCheckerMaker {
"computes double grad.")
.AsDuplicable()
.AsDispensable();
AddOutput("CUDAGraph", "The output CUDA Graph when use_cuda_graph=True.")
.AsDispensable();
AddAttr<BlockDesc*>("global_block",
"(BlockDesc *)"
"The global block of executed program desc.");
......@@ -107,6 +109,13 @@ class RunProgramOpMaker : public framework::OpProtoAndCheckerMaker {
"program_id",
"(int64_t)"
"The unique hash id used as cache key for ExecutorInfoCache.");
AddAttr<std::string>("cuda_graph_capture_mode",
"(str, default '') The CUDA Graph capture mode. "
"Default '' means no CUDA Graph capturing.")
.SetDefault("");
AddAttr<int64_t>("cuda_graph_pool_id",
"(int64_t, default 0) The CUDA Graph memory pool ID.")
.SetDefault(0);
AddComment(R"DOC(
RunProgram operator.
......@@ -191,6 +200,9 @@ class RunProgramGradOpMaker : public framework::SingleGradOpMaker<T> {
grad_op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
grad_op->SetInput("OutScope", this->Output("OutScope"));
grad_op->SetInput("DOut", this->Output("DOut"));
if (this->HasOutput("CUDAGraph")) {
grad_op->SetInput("CUDAGraph", this->Output("CUDAGraph"));
}
grad_op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
auto block_desc =
......
......@@ -34,6 +34,9 @@ limitations under the License. */
#ifdef PADDLE_WITH_MKLDNN
#include "paddle/fluid/platform/mkldnn_helper.h"
#endif
#ifdef PADDLE_WITH_CUDA
#include "paddle/fluid/operators/cuda_graph_with_in_out.h"
#endif
DECLARE_bool(use_mkldnn);
......@@ -167,13 +170,84 @@ static void ShareVarsFromScope(const std::vector<Variable *> &vars,
}
}
#ifdef PADDLE_WITH_CUDA
static cudaStreamCaptureMode StringToCUDAGraphCaptureMode(
const std::string &mode) {
if (mode == "global") {
return cudaStreamCaptureModeGlobal;
} else if (mode == "thread_local") {
return cudaStreamCaptureModeThreadLocal;
} else if (mode == "relaxed") {
return cudaStreamCaptureModeRelaxed;
} else {
PADDLE_THROW(phi::errors::InvalidArgument(
"Unsupported CUDA Graph capture mode %s", mode));
}
}
#endif
} // namespace details
template <typename DeviceContext, typename T>
class RunProgramOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &ctx) const override {
const auto &capture_mode = ctx.Attr<std::string>("cuda_graph_capture_mode");
auto is_test = ctx.Attr<bool>("is_test");
if (capture_mode.empty()) {
ComputeImpl(ctx, is_test, false);
return;
}
#ifdef PADDLE_WITH_CUDA
auto mode = details::StringToCUDAGraphCaptureMode(capture_mode);
PADDLE_ENFORCE_EQ(
platform::is_gpu_place(ctx.GetPlace()), true,
phi::errors::InvalidArgument("The cuda_graph_capture_mode is only "
"valid when using NVIDIA GPU."));
auto *graph_var = ctx.OutputVar("CUDAGraph");
PADDLE_ENFORCE_NOT_NULL(
graph_var,
phi::errors::InvalidArgument("Output(CUDAGraph) must exist when "
"cuda_graph_capture_mode is valid."));
using GraphVecType = std::vector<std::unique_ptr<CUDAGraphWithInOuts>>;
auto &inner_graphs = *(graph_var->GetMutable<GraphVecType>());
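// The graph vector holds up to three graphs: slot 0 is the forward graph for
// is_test=true, slot 1 the forward graph for is_test=false, and slot 2 the
// backward graph (see RunProgramGradOpKernel).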
inner_graphs.resize(std::max<size_t>(3, inner_graphs.size()));
size_t graph_idx = is_test ? 0 : 1;
if (inner_graphs[graph_idx].get() == nullptr) {
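// Share the memory pool with the already-captured sibling forward graph (the
// is_test counterpart) if it exists; otherwise fall back to the
// cuda_graph_pool_id attribute.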
int64_t pool_id;
if (inner_graphs[1 - graph_idx].get() != nullptr) {
pool_id = inner_graphs[1 - graph_idx]->PoolID();
} else {
pool_id = ctx.Attr<int64_t>("cuda_graph_pool_id");
}
framework::PEAndGraphPair pe_and_graph;
auto callable = [this, is_test, &pe_and_graph](
const framework::ExecutionContext &exe_ctx) {
pe_and_graph = ComputeImpl(exe_ctx, is_test, true);
};
inner_graphs[graph_idx] = CaptureCUDAGraph(
callable, ctx, {"X"}, {"Out", "DOut"}, mode, pool_id);
VLOG(10) << "Capture Forward CUDA Graph";
} else {
VLOG(10) << "Run Forward CUDA Graph directly";
ExecuteCUDAGraph(ctx, {"X"}, {"Out", "DOut"},
inner_graphs[graph_idx].get());
}
#else
PADDLE_THROW(
phi::errors::InvalidArgument("The cuda_graph_capture_mode is only "
"valid when using NVIDIA GPU."));
#endif
}
private:
framework::PEAndGraphPair ComputeImpl(const framework::ExecutionContext &ctx,
bool is_test,
bool use_cuda_graph) const {
VLOG(2) << "RunProgramOpKernel Compute";
framework::PEAndGraphPair pe_and_graph;
// Step 1. prepare inputs, outputs, attrs
auto &input_vars = ctx.MultiInputVar("X");
auto &param_vars = ctx.MultiInputVar("Params");
......@@ -192,7 +266,6 @@ class RunProgramOpKernel : public framework::OpKernel<T> {
auto start_op_index = ctx.Attr<int64_t>("start_op_index");
auto end_op_index = ctx.Attr<int64_t>("end_op_index");
auto is_test = ctx.Attr<bool>("is_test");
auto program_id = ctx.Attr<int64_t>("program_id");
// NOTE(chenweihang): In order not to add new variable type, use vector
......@@ -223,15 +296,29 @@ class RunProgramOpKernel : public framework::OpKernel<T> {
if (end_op_index > start_op_index) {
auto *program = global_block->Program();
auto cache_info = framework::GetExecutorInfoFromCache(
*program, ctx.GetPlace(), start_op_index, end_op_index,
/*is_grad=*/false, program_id, &scope);
auto &parallel_executor = cache_info.first;
bool is_new_created;
if (use_cuda_graph) {
pe_and_graph = framework::CreateFixOrderExecutorInfo(
*program, ctx.GetPlace(), start_op_index, end_op_index, &scope);
is_new_created = true;
} else {
auto cache_info = framework::GetExecutorInfoFromCache(
*program, ctx.GetPlace(), start_op_index, end_op_index,
/*is_grad=*/false, program_id, &scope);
pe_and_graph.first = cache_info.first;
is_new_created = cache_info.second;
}
auto &parallel_executor = pe_and_graph.first;
// all out_vars are skip_eager_var
std::vector<std::string> tmp_vars;
auto &skip_eager_delete_vars =
framework::ExecutorInfoCache::Instance().SkipEagerDeleteVars(
program_id, false);
if (cache_info.second /*is_new_created*/) {
use_cuda_graph
? tmp_vars
: framework::ExecutorInfoCache::Instance().SkipEagerDeleteVars(
program_id, false);
if (is_new_created) {
parallel_executor->SkipMemoryReuse(/*scope_idx=*/0, input_var_names);
skip_eager_delete_vars.insert(skip_eager_delete_vars.end(),
output_var_names.begin(),
......@@ -263,6 +350,7 @@ class RunProgramOpKernel : public framework::OpKernel<T> {
#ifdef PADDLE_WITH_MKLDNN
if (FLAGS_use_mkldnn) platform::DontClearMKLDNNCache(ctx.GetPlace());
#endif
return pe_and_graph;
}
};
......@@ -270,14 +358,68 @@ template <typename DeviceContext, typename T>
class RunProgramGradOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &ctx) const override {
const auto &capture_mode = ctx.Attr<std::string>("cuda_graph_capture_mode");
if (capture_mode.empty()) {
ComputeImpl(ctx, false);
return;
}
#ifdef PADDLE_WITH_CUDA
auto mode = details::StringToCUDAGraphCaptureMode(capture_mode);
PADDLE_ENFORCE_EQ(
platform::is_gpu_place(ctx.GetPlace()), true,
phi::errors::InvalidArgument("The cuda_graph_capture_mode is only "
"valid when using NVIDIA GPU."));
auto *graph_var =
const_cast<framework::Variable *>(ctx.InputVar("CUDAGraph"));
PADDLE_ENFORCE_NOT_NULL(
graph_var,
phi::errors::InvalidArgument("Output(CUDAGraph) must exist when "
"cuda_graph_capture_mode is valid."));
auto &inner_graphs = *(
graph_var
->GetMutable<std::vector<std::unique_ptr<CUDAGraphWithInOuts>>>());
const size_t graph_idx = 2;
if (inner_graphs[graph_idx].get() == nullptr) {
framework::PEAndGraphPair pe_and_graph;
auto callable =
[this, &pe_and_graph](const framework::ExecutionContext &exe_ctx) {
pe_and_graph = ComputeImpl(exe_ctx, true);
};
int64_t pool_id = inner_graphs[0].get() != nullptr
? inner_graphs[0]->PoolID()
: inner_graphs[1]->PoolID();
inner_graphs[graph_idx] =
CaptureCUDAGraph(callable, ctx, {framework::GradVarName("Out")},
{framework::GradVarName("X")}, mode, pool_id);
VLOG(10) << "Capture Backward CUDA Graph";
} else {
ExecuteCUDAGraph(ctx, {framework::GradVarName("Out")},
{framework::GradVarName("X")},
inner_graphs[graph_idx].get());
VLOG(10) << "Run Backward CUDA Graph directly";
}
#else
PADDLE_THROW(
phi::errors::InvalidArgument("The cuda_graph_capture_mode is only "
"valid when using NVIDIA GPU."));
#endif
}
private:
framework::PEAndGraphPair ComputeImpl(const framework::ExecutionContext &ctx,
bool use_cuda_graph) const {
VLOG(2) << "RunProgramGradOpKernel Compute";
framework::PEAndGraphPair pe_and_graph;
// Step 1. prepare inputs and outputs
auto &output_grad_vars = ctx.MultiInputVar(framework::GradVarName("Out"));
auto input_grad_vars = ctx.MultiOutputVar(framework::GradVarName("X"));
auto param_grad_vars = ctx.MultiOutputVar(framework::GradVarName("Params"));
// If all output vars are set to stop_gradient, the grad op does not need to be executed.
if (input_grad_vars.empty() && param_grad_vars.empty()) return;
if (input_grad_vars.empty() && param_grad_vars.empty()) {
return pe_and_graph;
}
auto output_grad_var_names = ctx.InputNames(framework::GradVarName("Out"));
// NOTE: after PR22939 [Add double grad] merged, the grad op maker's
......@@ -321,15 +463,27 @@ class RunProgramGradOpKernel : public framework::OpKernel<T> {
if (end_op_index > start_op_index) {
// Step 2. prepare executor and scope
auto *program = global_block->Program();
auto cache_info = framework::GetExecutorInfoFromCache(
*program, ctx.GetPlace(), start_op_index, end_op_index,
/*is_grad*/ true, program_id, &scope);
auto &parallel_executor = cache_info.first;
bool is_new_created;
if (use_cuda_graph) {
pe_and_graph = framework::CreateFixOrderExecutorInfo(
*program, ctx.GetPlace(), start_op_index, end_op_index, &scope);
is_new_created = true;
} else {
auto cache_info = framework::GetExecutorInfoFromCache(
*program, ctx.GetPlace(), start_op_index, end_op_index,
/*is_grad*/ true, program_id, &scope);
pe_and_graph.first = cache_info.first;
is_new_created = cache_info.second;
}
auto &parallel_executor = pe_and_graph.first;
std::vector<std::string> tmp_vars;
auto &skip_eager_delete_vars =
framework::ExecutorInfoCache::Instance().SkipEagerDeleteVars(
program_id, true);
if (cache_info.second /*is_new_created*/) {
use_cuda_graph
? tmp_vars
: framework::ExecutorInfoCache::Instance().SkipEagerDeleteVars(
program_id, true);
if (is_new_created) {
parallel_executor->SkipMemoryReuse(/*scope_idx=*/0,
output_grad_var_names);
......@@ -360,6 +514,7 @@ class RunProgramGradOpKernel : public framework::OpKernel<T> {
global_inner_scope->DeleteScope(&scope);
VLOG(2) << "The number of sub scopes after backward: "
<< global_inner_scope->kids().size();
return pe_and_graph;
}
};
......
......@@ -16,23 +16,33 @@
#include "paddle/fluid/memory/allocation/allocator_facade.h"
#include "paddle/fluid/platform/device_context.h"
DECLARE_bool(use_stream_safe_cuda_allocator);
namespace paddle {
namespace platform {
#ifdef PADDLE_WITH_CUDA
void BeginCUDAGraphCapture(platform::CUDAPlace place,
cudaStreamCaptureMode mode) {
cudaStreamCaptureMode mode, int64_t pool_id) {
auto *dev_ctx = platform::DeviceContextPool::Instance().GetByPlace(place);
dev_ctx->cudnn_workspace_handle().ResetWorkspace();
auto stream = dev_ctx->stream();
CUDAGraph::BeginCapture(place, stream, mode);
auto id = CUDAGraph::CapturingID();
auto old_value = FLAGS_use_stream_safe_cuda_allocator;
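// Temporarily turn off the stream-safe CUDA allocator while the capture
// memory pool is prepared; the flag is restored right after
// PrepareMemoryPoolForCUDAGraph.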
if (old_value) {
FLAGS_use_stream_safe_cuda_allocator = false;
}
pool_id = CUDAGraph::SetMemoryPoolID(pool_id);
memory::allocation::AllocatorFacade::Instance().PrepareMemoryPoolForCUDAGraph(
id);
AddResetCallbackIfCapturingCUDAGraph([id] {
pool_id);
if (old_value) {
FLAGS_use_stream_safe_cuda_allocator = true;
}
AddResetCallbackIfCapturingCUDAGraph([pool_id] {
memory::allocation::AllocatorFacade::Instance().RemoveMemoryPoolOfCUDAGraph(
id);
pool_id);
});
}
......
......@@ -23,10 +23,53 @@
namespace paddle {
namespace platform {
#ifdef PADDLE_WITH_CUDA
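// Launches __kernel_func as usual. When a CUDA Graph is being captured and
// __cond holds, it additionally records a callback that recomputes the random
// seed/offset (via GetSeedDataAndIncrement with __seed_inc) and patches them
// into the captured kernel's parameters (__seed_expr / __offset_expr refer to
// KERNEL_PARAMS slots) before every replay.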
#define PD_RECORD_CUDA_GRAPH_RANDOM_KERNEL(__cond, __kernel_func, __grid, \
__block, __sm_size, __stream, \
__seed_inc, __seed_expr, \
__offset_expr, ...) \
do { \
if (::paddle::platform::CUDAGraph::IsThisThreadCapturing() && (__cond)) { \
using __Helper = \
::paddle::platform::IsSameKernelHelper<decltype(&__kernel_func), \
&__kernel_func>; \
auto *dev_ctx = \
::paddle::platform::DeviceContextPool::Instance().GetByPlace( \
::paddle::platform::CUDAGraph::CapturingPlace()); \
auto __set_seed_func = [=]( \
::paddle::platform::CUDAKernelParams *__params, \
bool __check_only) -> bool { \
if (__check_only) { \
return __params->func() == &__kernel_func && \
__Helper::Compare(*__params, __VA_ARGS__); \
} \
auto &KERNEL_PARAMS = *__params; \
uint64_t __seed, __offset; \
::paddle::operators::GetSeedDataAndIncrement( \
*dev_ctx, nullptr, false, 0, __seed_inc, &__seed, &__offset); \
__seed_expr = static_cast<decltype(__seed_expr)>(__seed); \
__offset_expr = static_cast<decltype(__offset_expr)>(__offset); \
return true; \
}; \
::paddle::platform::CUDAGraph::RecordRandomKernelInfo(__set_seed_func); \
} \
__kernel_func<<<__grid, __block, __sm_size, __stream>>>(__VA_ARGS__); \
} while (0)
#else
#define PD_RECORD_CUDA_GRAPH_RANDOM_KERNEL(__cond, __kernel_func, __grid, \
__block, __sm_size, __stream, \
__seed_inc, __seed_expr, \
__offset_expr, ...) \
do { \
__kernel_func<<<__grid, __block, __sm_size, __stream>>>(__VA_ARGS__); \
} while (0)
#endif
// NOTE: These APIs are not thread-safe.
#ifdef PADDLE_WITH_CUDA
void BeginCUDAGraphCapture(platform::CUDAPlace place,
cudaStreamCaptureMode mode);
cudaStreamCaptureMode mode,
int64_t pool_id = CUDAGraph::kInvalidPoolID);
std::unique_ptr<CUDAGraph> EndCUDAGraphCapture();
#endif
......
......@@ -13,6 +13,9 @@
// limitations under the License.
#include "paddle/fluid/platform/device/gpu/cuda/cuda_graph.h"
#include <queue>
#include <unordered_map>
#include <unordered_set>
namespace paddle {
namespace platform {
......@@ -20,6 +23,69 @@ namespace platform {
std::unique_ptr<CUDAGraph> CUDAGraph::capturing_graph_{nullptr};
paddle::optional<std::thread::id> CUDAGraph::capturing_thread_id_{paddle::none};
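// Topologically sort the nodes of the captured cudaGraph_t (Kahn's algorithm)
// so that the recorded seed-setting callbacks, which are stored in launch
// order, can be matched to kernel nodes deterministically.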
static std::vector<cudaGraphNode_t> ToposortCUDAGraph(cudaGraph_t graph) {
size_t num_nodes;
PADDLE_ENFORCE_GPU_SUCCESS(cudaGraphGetNodes(graph, nullptr, &num_nodes));
std::vector<cudaGraphNode_t> nodes(num_nodes);
PADDLE_ENFORCE_GPU_SUCCESS(
cudaGraphGetNodes(graph, nodes.data(), &num_nodes));
size_t num_edges;
PADDLE_ENFORCE_GPU_SUCCESS(
cudaGraphGetEdges(graph, nullptr, nullptr, &num_edges));
std::vector<cudaGraphNode_t> from(num_edges), to(num_edges);
PADDLE_ENFORCE_GPU_SUCCESS(
cudaGraphGetEdges(graph, from.data(), to.data(), &num_edges));
std::unordered_map<cudaGraphNode_t, std::unordered_set<cudaGraphNode_t>>
in_edges, out_edges;
for (auto node : nodes) {
in_edges[node];
out_edges[node];
}
for (size_t i = 0; i < num_edges; ++i) {
in_edges[to[i]].insert(from[i]);
out_edges[from[i]].insert(to[i]);
}
std::queue<cudaGraphNode_t> q;
for (const auto &pair : in_edges) {
if (pair.second.empty()) {
q.push(pair.first);
}
}
nodes.clear();
while (!q.empty()) {
auto cur = q.front();
q.pop();
nodes.push_back(cur);
for (auto out_node : out_edges.at(cur)) {
auto &in_nodes = in_edges.at(out_node);
in_nodes.erase(cur);
if (in_nodes.empty()) {
q.push(out_node);
}
}
}
PADDLE_ENFORCE_EQ(
nodes.size(), num_nodes,
phi::errors::InvalidArgument("Toposort error, this may be a bug."));
return nodes;
}
CUDAGraphID CUDAGraph::UniqueID() {
static std::atomic<CUDAGraphID> id;
return id.fetch_add(1);
}
int64_t CUDAGraph::UniqueMemoryPoolID() {
static std::atomic<int64_t> id(CUDAGraph::kDefaultPoolID + 1);
return id.fetch_add(1);
}
void CUDAGraph::Reset() {
if (is_reset_) return;
#if CUDA_VERSION >= 10010
......@@ -46,9 +112,16 @@ void CUDAGraph::Replay() {
PADDLE_ENFORCE_EQ(is_reset_, false,
errors::PermissionDenied(
"Cannot replay the CUDA Graph after reset is called."));
for (auto exec_graph : exec_graphs_) {
PADDLE_ENFORCE_GPU_SUCCESS(cudaGraphLaunch(exec_graph, stream_));
size_t n = exec_graphs_.size();
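// From the second run onwards, the pre-hooks patch mutable kernel parameters
// (e.g. random seed/offset) into each executable graph before it is launched.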
for (size_t i = 0; i < n; ++i) {
if (!is_first_run_) {
for (auto &hook : pre_hooks_[i]) {
hook(exec_graphs_[i]);
}
}
PADDLE_ENFORCE_GPU_SUCCESS(cudaGraphLaunch(exec_graphs_[i], stream_));
}
is_first_run_ = false;
#endif
}
......@@ -72,7 +145,8 @@ void CUDAGraph::BeginSegmentCapture() {
platform::errors::PermissionDenied(
"CUDA Graph should not be invalidated."));
VLOG(10) << "Begin to capture CUDA Graph with ID " << capturing_graph_->id_
<< ", segment id " << capturing_graph_->graphs_.size();
<< ", segment id " << capturing_graph_->graphs_.size()
<< ", memory pool id " << capturing_graph_->pool_id_;
#endif
}
......@@ -112,15 +186,57 @@ void CUDAGraph::EndSegmentCapture() {
if (num_nodes == 0) {
PADDLE_ENFORCE_GPU_SUCCESS(cudaGraphDestroy(graph));
VLOG(10) << "Skip empty CUDA Graph with ID " << capturing_graph_->id_
<< ", segment id " << capturing_graph_->graphs_.size();
<< ", segment id " << capturing_graph_->graphs_.size()
<< ", memory pool id " << capturing_graph_->pool_id_;
return;
}
auto sorted_nodes = ToposortCUDAGraph(graph);
capturing_graph_->pre_hooks_.emplace_back();
std::unordered_set<cudaGraphNode_t> visited;
VLOG(10) << "SetSeedFunc number : "
<< capturing_graph_->set_seed_funcs_.size();
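// Match every recorded seed-setting callback to its kernel node (same function
// pointer and bitwise-equal arguments), and register a pre-hook that refreshes
// that node's seed/offset parameters on each replay.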
for (const auto &set_seed_func : capturing_graph_->set_seed_funcs_) {
bool found = false;
for (auto node : sorted_nodes) {
if (visited.count(node) > 0) continue;
cudaGraphNodeType type;
PADDLE_ENFORCE_GPU_SUCCESS(cudaGraphNodeGetType(node, &type));
if (type == cudaGraphNodeTypeKernel) {
cudaKernelNodeParams params;
auto err = cudaGraphKernelNodeGetParams(node, &params);
if (err == cudaErrorInvalidDeviceFunction) {
continue;
} else {
PADDLE_ENFORCE_GPU_SUCCESS(err);
}
CUDAKernelParams kernel_params(&params);
if (set_seed_func(&kernel_params, true)) {
capturing_graph_->pre_hooks_.back().push_back(
[set_seed_func, node, params](cudaGraphExec_t exec_graph) {
CUDAKernelParams kernel_params(&params);
set_seed_func(&kernel_params, false);
PADDLE_ENFORCE_GPU_SUCCESS(cudaGraphExecKernelNodeSetParams(
exec_graph, node, &params));
});
visited.insert(node);
found = true;
break;
}
}
}
PADDLE_ENFORCE_EQ(found, true,
phi::errors::InvalidArgument(
"Cannot find the corresponding random CUDA kernel."));
}
capturing_graph_->set_seed_funcs_.clear();
cudaGraphExec_t exec_graph;
PADDLE_ENFORCE_GPU_SUCCESS(
cudaGraphInstantiate(&exec_graph, graph, nullptr, nullptr, 0));
VLOG(10) << "End to capture CUDA Graph with ID " << capturing_graph_->id_
<< ", segment id " << capturing_graph_->graphs_.size();
<< ", segment id " << capturing_graph_->graphs_.size()
<< ", memory pool id " << capturing_graph_->pool_id_;
capturing_graph_->graphs_.emplace_back(graph);
capturing_graph_->exec_graphs_.emplace_back(exec_graph);
#endif
......
......@@ -32,6 +32,70 @@
namespace paddle {
namespace platform {
template <typename T>
static bool IsBitwiseEqual(const T &x, const T &y) {
return std::memcmp(&x, &y, sizeof(T)) == 0;
}
class CUDAKernelParams {
public:
explicit CUDAKernelParams(const cudaKernelNodeParams *params)
: params_(params) {}
const void *func() const { return params_->func; }
template <typename T>
T &As(size_t idx) const {
return *reinterpret_cast<T *>(params_->kernelParams[idx]);
}
private:
const cudaKernelNodeParams *params_;
};
template <typename F, F f>
struct IsSameKernelHelper;
template <typename Return, typename... FuncArgs,
Return (*kernel_fn)(FuncArgs...)>
struct IsSameKernelHelper<Return (*)(FuncArgs...), kernel_fn> {
private:
template <typename TupleT, size_t IDX, bool IsEnd /*=false*/>
struct Impl {
static bool Compare(const CUDAKernelParams &params, const TupleT &args) {
using CompareT = typename std::tuple_element<IDX, TupleT>::type;
if (!IsBitwiseEqual<CompareT>(params.As<CompareT>(IDX),
std::get<IDX>(args))) {
return false;
}
constexpr auto NewIsEnd = (IDX + 1 == std::tuple_size<TupleT>::value);
return Impl<TupleT, IDX + 1, NewIsEnd>::Compare(params, args);
}
};
template <typename TupleT, size_t IDX>
struct Impl<TupleT, IDX, true> {
static bool Compare(const CUDAKernelParams &params, const TupleT &args) {
return true;
}
};
public:
using FuncArgsTuple = decltype(std::make_tuple(std::declval<FuncArgs>()...));
template <typename... Args>
static bool Compare(const CUDAKernelParams &params, Args... args) {
constexpr auto kNumArgs = sizeof...(FuncArgs);
static_assert(kNumArgs == sizeof...(Args), "The argument number does not match");
auto args_tuple = std::make_tuple(args...);
using TupleT = typename std::decay<decltype(args_tuple)>::type;
return Impl<TupleT, 0, kNumArgs == 0>::Compare(params, args_tuple);
}
};
#if CUDA_VERSION >= 10010
static void ThrowErrorIfNotSupportCUDAGraph() {}
#else
......@@ -61,10 +125,35 @@ class CUDAGraph {
}
public:
static constexpr int64_t kDefaultPoolID = 0;
static constexpr int64_t kInvalidPoolID = -1;
~CUDAGraph() { Reset(); }
CUDAGraphID ID() const { return id_; }
static int64_t SetMemoryPoolID(int64_t pool_id) {
auto &pool_id_ = capturing_graph_->pool_id_;
PADDLE_ENFORCE_EQ(
pool_id_, kInvalidPoolID,
phi::errors::InvalidArgument("Cannot reset memory pool id twice, the "
"former memory pool id is %d.",
pool_id_));
if (pool_id <= kInvalidPoolID) {
pool_id_ = UniqueMemoryPoolID();
} else {
PADDLE_ENFORCE_GE(
pool_id, kDefaultPoolID,
phi::errors::InvalidArgument("Invalid memory pool id %d.", pool_id));
pool_id_ = pool_id;
}
return pool_id_;
}
int64_t PoolID() const { return pool_id_; }
static int64_t CapturingPoolID() { return capturing_graph_->pool_id_; }
void Replay();
void Reset();
......@@ -120,12 +209,17 @@ class CUDAGraph {
}
}
private:
static CUDAGraphID UniqueID() {
static std::atomic<CUDAGraphID> id;
return id.fetch_add(1);
using SetSeedFunc = std::function<bool(CUDAKernelParams *, bool)>;
static void RecordRandomKernelInfo(SetSeedFunc set_seed_func) {
std::lock_guard<std::mutex> guard(capturing_graph_->func_mtx_);
capturing_graph_->set_seed_funcs_.emplace_back(std::move(set_seed_func));
}
static int64_t UniqueMemoryPoolID();
private:
static CUDAGraphID UniqueID();
private:
#if CUDA_VERSION >= 10010
std::vector<cudaGraph_t> graphs_;
......@@ -135,10 +229,17 @@ class CUDAGraph {
cudaStream_t stream_{nullptr};
platform::CUDAPlace place_;
CUDAGraphID id_;
int64_t pool_id_{kInvalidPoolID};
std::vector<std::function<void()>> callbacks_;
bool is_reset_{false};
std::mutex mtx_;
std::vector<SetSeedFunc> set_seed_funcs_;
std::vector<std::vector<std::function<void(cudaGraphExec_t)>>> pre_hooks_;
std::mutex func_mtx_;
bool is_first_run_{true};
static paddle::optional<std::thread::id> capturing_thread_id_;
static std::unique_ptr<CUDAGraph> capturing_graph_;
};
......
......@@ -27,7 +27,8 @@ static PyObject *eager_api_run_program(PyObject *self, PyObject *args,
GetScopePtrListFromArgs("run_program", "OutScope", args, 3, false);
auto DOut = GetTensorPtrListFromArgs("run_program", "DOut", args, 4, true);
framework::AttributeMap attrs;
ConstructAttrMapFromPyArgs("run_program", args, 5, PyTuple_GET_SIZE(args),
// TODO(zengjinle): support CUDA Graph on eager mode
ConstructAttrMapFromPyArgs("run_program", args, 6, PyTuple_GET_SIZE(args),
attrs);
tstate = PyEval_SaveThread();
......
......@@ -640,10 +640,11 @@ void CastPyArg2AttrBlock(PyObject* obj,
void ConstructAttrMapFromPyArgs(
const std::string& op_type, PyObject* args, ssize_t attr_start,
ssize_t attr_end, paddle::framework::AttributeMap& attrs) { // NOLINT
PADDLE_ENFORCE_EQ(
(attr_end - attr_start) % 2, 0,
platform::errors::InvalidArgument(
"The number of arguments for attributes should be even."));
PADDLE_ENFORCE_EQ((attr_end - attr_start) % 2, 0,
platform::errors::InvalidArgument(
"The number of arguments for attributes should be even "
"but attr_start = %d, attr_end = %d.",
attr_start, attr_end));
auto attr_type_map = &(OpAttrTypeMap::Instance().Map()[op_type]);
......
......@@ -182,7 +182,7 @@ std::map<std::string, std::set<std::string>> op_outs_map = {
{"merged_momentum", {"ParamOut", "VelocityOut", "MasterParamOut"}},
{"sparse_momentum", {"ParamOut", "VelocityOut", "MasterParamOut"}},
{"rnn", {"DropoutState", "Reserve", "Out", "State"}},
{"run_program", {"DOut"}},
{"run_program", {"DOut", "CUDAGraph"}},
{"adam",
{"ParamOut", "Moment1Out", "Moment2Out", "Beta1PowOut", "Beta2PowOut",
"MasterParamOut"}},
......@@ -267,7 +267,7 @@ std::map<std::string, std::set<std::string>> op_passing_outs_map = {
{"moving_average_abs_max_scale",
{"Out", "OutScale", "OutAccum", "OutState"}},
{"rnn", {"DropoutState"}},
{"run_program", {"Out", "DOut", "OutScope"}},
{"run_program", {"Out", "DOut", "OutScope", "CUDAGraph"}},
{"clear_float_status", {"FloatStatusOut"}},
{"get_float_status", {"FloatStatusOut"}},
{"assign", {"Out"}},
......
......@@ -604,6 +604,8 @@ PYBIND11_MODULE(core_noavx, m) {
place, static_cast<cudaStreamCaptureMode>(mode));
})
.def_static("end_capture", &platform::EndCUDAGraphCapture)
.def_static("gen_new_memory_pool_id",
&platform::CUDAGraph::UniqueMemoryPoolID)
.def("replay", &platform::CUDAGraph::Replay)
.def("reset", &platform::CUDAGraph::Reset)
.def("print_to_dot_files", &platform::CUDAGraph::PrintToDotFiles);
......
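The pybind bindings above add gen_new_memory_pool_id next to the existing capture entry points. For orientation, a minimal sketch of driving the Python-side CUDAGraph wrapper directly is given below; the capture_begin/capture_end method names are assumptions (only print_to_dot_files appears in the next hunk), so treat this as illustrative rather than part of this commit.

import paddle
from paddle.device.cuda.graphs import CUDAGraph, is_cuda_graph_supported

if is_cuda_graph_supported():
    x = paddle.full(shape=[1], fill_value=1.0, dtype='float32')
    g = CUDAGraph(mode="thread_local")
    g.capture_begin()   # assumed wrapper around the core begin_capture binding
    y = x + 10          # work is recorded into the graph, not executed yet
    g.capture_end()     # assumed wrapper around the core end_capture binding
    g.replay()          # launch the captured kernels
    g.reset()           # release the captured graph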
......@@ -17,15 +17,23 @@ from paddle.fluid.core import is_compiled_with_cuda, is_compiled_with_rocm, CUDA
if is_compiled_with_cuda() and not is_compiled_with_rocm():
from paddle.fluid.core import CUDAGraph as CoreCUDAGraph
def is_cuda_graph_supported():
return True
else:
CoreCUDAGraph = None
def is_cuda_graph_supported():
return False
ALL_MODES = ["global", "thread_local", "relaxed"]
class CUDAGraph:
def __init__(self, place=None, mode="thread_local"):
assert CoreCUDAGraph is not None, "CUDA Graph is only supported on PaddlePaddle compiled with NVIDIA GPU."
ALL_MODES = ["global", "thread_local", "relaxed"]
self._graph = None
if place is None:
device_id = int(os.environ.get('FLAGS_selected_gpus', 0))
......@@ -55,3 +63,25 @@ class CUDAGraph:
if flags is None:
flags = 2047  # output all information; it can be any integer inside [1, 2048)
self._graph.print_to_dot_files(dirname, flags)
def wrap_cuda_graph(function, mode="thread_local", memory_pool="default"):
assert mode in ALL_MODES
from paddle.jit import to_static
from paddle.nn import Layer
new_function = to_static(function)
if isinstance(function, Layer):
mock_func = new_function.forward
else:
mock_func = new_function
mock_func._cuda_graph_capture_mode = mode
if memory_pool == "default":
mock_func._cuda_graph_pool_id = 0
elif memory_pool == "new":
mock_func._cuda_graph_pool_id = CoreCUDAGraph.gen_new_memory_pool_id()
else:
if isinstance(memory_pool, Layer):
mock_func._cuda_graph_pool_id = memory_pool.forward._cuda_graph_pool_id
else:
mock_func._cuda_graph_pool_id = memory_pool._cuda_graph_pool_id
return new_function
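A condensed usage sketch of wrap_cuda_graph follows; it mirrors the unit test added later in this commit, and MyNet is an illustrative stand-in rather than anything from the diff.

import paddle
import paddle.nn as nn
from paddle.device.cuda.graphs import wrap_cuda_graph, is_cuda_graph_supported

class MyNet(nn.Layer):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(10, 10)

    def forward(self, x):
        return self.linear(x)

if is_cuda_graph_supported():
    # memory_pool can be "default", "new", or another wrapped layer/function
    # whose memory pool should be shared.
    net = wrap_cuda_graph(MyNet(), mode="thread_local", memory_pool="new")
    for _ in range(10):
        x = paddle.randn([3, 10], dtype='float32')
        loss = net(x).mean()   # the first forward/backward captures CUDA graphs,
        loss.backward()        # later iterations replay them
        net.clear_gradients()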
......@@ -148,6 +148,9 @@ class PartialProgramLayer:
self._origin_main_program = self._verify_program(main_program)
self._tmp_scope_vec = self._create_scope_vec()
self._cuda_graph_vec = self._create_cuda_graph_vec()
self._cuda_graph_capture_mode = ""
self._cuda_graph_pool_id = 0
# Set default mode to train
self.training = True
......@@ -339,9 +342,15 @@ class PartialProgramLayer:
def __call__(self, inputs):
in_vars, out_vars = self._prepare(inputs)
attrs = ('global_block', self.program.desc.block(0), 'start_op_index',
0, 'end_op_index', self._get_end_op_index(), 'is_test',
not self.training, 'program_id', self.program_id)
attrs = [
'global_block', self.program.desc.block(0), 'start_op_index', 0,
'end_op_index', self._get_end_op_index(), 'is_test',
not self.training, 'program_id', self.program_id
]
if self._cuda_graph_capture_mode:
attrs.extend(
('cuda_graph_capture_mode', self._cuda_graph_capture_mode,
'cuda_graph_pool_id', self._cuda_graph_pool_id))
self._cast_fp16_if_pure_fp16(in_vars)
......@@ -349,7 +358,7 @@ class PartialProgramLayer:
self._valid_vars(in_vars),
self._valid_vars(self._params),
self._valid_vars(out_vars), self._tmp_scope_vec, self._double_grads,
*attrs)
self._cuda_graph_vec, *attrs)
self.drop_scope_if_no_grad()
restored_nest_out = self._restore_out(out_vars)
return self._remove_no_value(restored_nest_out)
......@@ -471,6 +480,12 @@ class PartialProgramLayer:
tmp_scope_vec = [inner_scope]
return tmp_scope_vec
def _create_cuda_graph_vec(self):
var = core.VarBase(core.VarDesc.VarType.FP32, [], "cuda_graph",
core.VarDesc.VarType.RAW, True)
var.stop_gradient = True
return var
def _restore_out(self, out_vars):
"""
Restores same nested outputs by only replacing the Variable with VarBase.
......
......@@ -267,6 +267,8 @@ class StaticFunction(object):
self._program_trans = ProgramTranslator()
self._kwargs = kwargs
self._training = True
self._cuda_graph_capture_mode = ""
self._cuda_graph_pool_id = 0
def train(self):
if isinstance(self._class_instance,
......@@ -367,6 +369,9 @@ class StaticFunction(object):
else:
partial_program_layer.training = self._training
partial_program_layer._cuda_graph_capture_mode = self._cuda_graph_capture_mode
partial_program_layer._cuda_graph_pool_id = self._cuda_graph_pool_id
# 4. return outputs.
try:
return partial_program_layer(args)
......
......@@ -874,7 +874,7 @@ def _run_dygraph(instance, input, program_holder):
_valid_vars(input_vars),
_valid_vars(persistable_vars),
_valid_vars(output_vars), tmp_scope_vec,
_valid_vars(double_grad_vars), *attrs)
_valid_vars(double_grad_vars), None, *attrs)
# NOTE: [ why need set param's gradient type here ]
# if user set sparse gradient mode, the param's gradient
# will be SelectedRows, not LoDTensor. But tracer will just
......
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
import paddle.nn as nn
import unittest
import numpy as np
from paddle.device.cuda.graphs import wrap_cuda_graph, is_cuda_graph_supported
class SimpleModel(nn.Layer):
def __init__(self, in_size, out_size):
super(SimpleModel, self).__init__()
self.linear = nn.Linear(in_size, out_size)
self.dropout_1 = paddle.nn.Dropout(0.1)
self.relu = nn.ReLU()
self.dropout_2 = paddle.nn.Dropout(0.5)
self.gelu = nn.GELU()
def forward(self, x):
x = self.linear(x)
x = self.dropout_1(x)
x = self.relu(x)
x = self.dropout_2(x)
x = self.gelu(x)
return x
class TestSimpleModel(unittest.TestCase):
def setUp(self):
paddle.set_flags({'FLAGS_eager_delete_tensor_gb': 0.0})
def run_base(self, func, use_cuda_graph, memory_pool="default", seed=10):
paddle.seed(seed)
is_layer = isinstance(func, paddle.nn.Layer)
if use_cuda_graph:
func = wrap_cuda_graph(func, memory_pool=memory_pool)
for _ in range(10):
x = paddle.randn([3, 10], dtype='float32')
x.stop_gradient = False
y = x * x + 100
loss = func(y).mean()
loss.backward()
if is_layer:
func.clear_gradients()
return func, x.grad.numpy()
def check(self, func):
if not is_cuda_graph_supported():
return
_, value1 = self.run_base(func, False)
layer, value2 = self.run_base(func, True, "default")
_, value3 = self.run_base(func, True, "new")
_, value4 = self.run_base(func, True, layer)
self.assertTrue(np.array_equal(value1, value2))
self.assertTrue(np.array_equal(value1, value3))
self.assertTrue(np.array_equal(value1, value4))
def test_layer(self):
self.check(SimpleModel(10, 20))
if __name__ == "__main__":
unittest.main()
......@@ -104,7 +104,7 @@ class TestRunProgram(unittest.TestCase):
'is_test', False, 'program_id', _hash_with_id(program))
_C_ops.run_program([x_t, y_t], [fake_var], [out_t], [scope],
[fake_var], *attrs)
[fake_var], None, *attrs)
loss = paddle.mean(out_t)
loss.backward()
......
......@@ -188,7 +188,7 @@ class RunProgramOpTest(unittest.TestCase):
outputs = self.prepare_dygraph_output()
_C_ops.run_program(inputs['X'], inputs['Params'], outputs['Out'],
outputs['OutScope'], outputs['DOut'],
outputs['OutScope'], outputs['DOut'], None,
*self.attrs)
return outputs['Out']
......@@ -202,7 +202,7 @@ class RunProgramOpTest(unittest.TestCase):
outputs = self.prepare_dygraph_output()
_C_ops.run_program(inputs['X'], inputs['Params'], outputs['Out'],
outputs['OutScope'], outputs['DOut'],
outputs['OutScope'], outputs['DOut'], None,
*self.attrs)
for param in input_param_list:
......