Unverified commit d05b940a, authored by S sneaxiy, committed by GitHub

Support CUDA Graph for partial graph in dygraph mode (#42786)

* support CUDAGraph for partial graph

* add ut

* fix ci

* fix ut again because of eager mode

* fix kunlun ci

* fix win ci
Parent 126248ac
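Usage sketch (editorial note, not part of the commit): the snippet below shows how the dygraph partial-graph CUDA Graph capture added here is meant to be driven from Python, based on the wrap_cuda_graph / is_cuda_graph_supported helpers and the unit test added later in this diff; the layer, tensor shapes, and loop count are illustrative only.

import paddle
import paddle.nn as nn
from paddle.device.cuda.graphs import wrap_cuda_graph, is_cuda_graph_supported

class Block(nn.Layer):
    def __init__(self):
        super(Block, self).__init__()
        self.linear = nn.Linear(10, 10)
        self.dropout = nn.Dropout(0.5)

    def forward(self, x):
        return self.dropout(self.linear(x))

if is_cuda_graph_supported():
    # dy2static wraps the block in run_program ops with cuda_graph_capture_mode set
    layer = wrap_cuda_graph(Block())
    for _ in range(5):
        x = paddle.randn([4, 10], dtype='float32')
        x.stop_gradient = False
        loss = layer(x).mean()
        loss.backward()
        layer.clear_gradients()

On the first iterations the run_program op captures the forward (and later the backward) sub-program into separate CUDA Graphs that share one memory pool; subsequent iterations only copy the inputs into the captured tensors and replay the graphs.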
@@ -46,6 +46,12 @@ FastThreadedSSAGraphExecutor::FastThreadedSSAGraphExecutor(
VLOG(10)
<< "Change thread number to 1 because the toposort order is unique";
strategy_.num_threads_ = 1;
traced_ops_.clear();
for (auto *op_node : TopologySortOperations(*graph_)) {
if (op_node->IsWrappedBy<OpHandleBase>()) {
traced_ops_.emplace_back(&(op_node->Wrapper<OpHandleBase>()));
}
}
}
pool_.reset(new ::ThreadPool(strategy.num_threads_));
for (auto &op : ir::FilterByNodeWrapper<OpHandleBase>(*graph_)) {
......
@@ -137,6 +137,31 @@ ExecutorInfoCache &ExecutorInfoCache::Instance() {
return g_exe_cache_info_map;
}
static PEAndGraphPair CreateExecutorInfo(
const ProgramDesc &program_desc, const platform::Place &place,
int64_t start_op_index, int64_t end_op_index, framework::Scope *scope,
const details::BuildStrategy &build_strategy) {
auto execution_strategy = details::GetExecutionStrategy(place);
auto graph = std::make_shared<framework::ir::Graph>(
program_desc, start_op_index, end_op_index);
auto parallel_executor = std::make_shared<framework::ParallelExecutor>(
place, scope, execution_strategy, build_strategy, graph.get());
parallel_executor->PrepareVariables(scope);
return std::make_pair(parallel_executor, graph);
}
PEAndGraphPair CreateFixOrderExecutorInfo(const ProgramDesc &program_desc,
const platform::Place &place,
int64_t start_op_index,
int64_t end_op_index,
framework::Scope *scope) {
details::BuildStrategy build_strategy;
build_strategy.fix_op_run_order_ = true;
auto pe_and_graph = CreateExecutorInfo(program_desc, place, start_op_index,
end_op_index, scope, build_strategy);
return pe_and_graph;
}
CacheInfo GetExecutorInfoFromCache(const ProgramDesc &program_desc,
const platform::Place &place,
int64_t start_op_index, int64_t end_op_index,
@@ -153,21 +178,17 @@ CacheInfo GetExecutorInfoFromCache(const ProgramDesc &program_desc,
}
VLOG(1) << "create exe_info for " << program_id << " is_grad: " << is_grad;
- auto execution_strategy = details::GetExecutionStrategy(place);
auto &build_strategy = cached_exe_info.GetBuildStrategy(program_id);
// 2. Construct Graph and ParallelExecutor.
- auto graph = std::make_shared<framework::ir::Graph>(
-     program_desc, start_op_index, end_op_index);
- auto parallel_executor = std::make_shared<framework::ParallelExecutor>(
-     place, scope, execution_strategy, build_strategy, graph.get());
- parallel_executor->PrepareVariables(scope);
+ auto pe_and_graph = CreateExecutorInfo(program_desc, place, start_op_index,
+                                        end_op_index, scope, build_strategy);
// 3. Insert value into cached map.
auto &cached_value = cached_exe_info.GetMutable(program_id, is_grad);
- cached_value.executor_ = parallel_executor;
- cached_value.graph_ = std::move(graph);
- return std::make_pair(parallel_executor, /*is_new_created=*/true);
+ cached_value.executor_ = pe_and_graph.first;
+ cached_value.graph_ = pe_and_graph.second;
+ return std::make_pair(pe_and_graph.first, /*is_new_created=*/true);
} else {
VLOG(1) << "get exe_info from cache by: " << program_id
<< " is_grad: " << is_grad;
......
@@ -127,11 +127,20 @@ class ExecutorInfoCache {
using CacheInfo =
    std::pair<std::shared_ptr<ParallelExecutor>, bool /*is_new_created*/>;
using PEAndGraphPair =
std::pair<std::shared_ptr<ParallelExecutor>, std::shared_ptr<ir::Graph>>;
CacheInfo GetExecutorInfoFromCache(const ProgramDesc& program_desc,
const platform::Place& place,
int64_t start_op_index, int64_t end_op_index,
bool is_grad, int64_t program_id,
framework::Scope* scope);
PEAndGraphPair CreateFixOrderExecutorInfo(const ProgramDesc& program_desc,
const platform::Place& place,
int64_t start_op_index,
int64_t end_op_index,
framework::Scope* scope);
} // namespace framework
} // namespace paddle
@@ -41,6 +41,8 @@
#include "paddle/fluid/platform/device/xpu/bkcl_helper.h"
#endif
#include "paddle/fluid/operators/cuda_graph_with_in_out.h"
namespace paddle {
namespace framework {
......
@@ -87,6 +87,8 @@ namespace operators {
class CudnnRNNCache;
class CUDAGraphWithInOuts;
namespace reader {
class LoDTensorBlockingQueueHolder;
class OrderedMultiDeviceLoDTensorBlockingQueueHolder;
@@ -189,7 +191,8 @@ using VarTypeRegistry = detail::VarTypeRegistryImpl<
#if defined(PADDLE_WITH_CNCL)
cnclCliqueId,
#endif
- int, float, Vocab>;
+ std::vector<std::unique_ptr<operators::CUDAGraphWithInOuts>>, int, float,
+ Vocab>;
template <typename T>
struct VarTypeTrait {
static_assert(VarTypeRegistry::IsRegistered<T>(), "Must be registered type");
......
@@ -123,6 +123,8 @@ class CUDAGraphAllocator
: underlying_allocator_(allocator) {}
public:
~CUDAGraphAllocator() { VLOG(10) << "CUDAGraphAllocator destructed"; }
static std::shared_ptr<Allocator> Create(
const std::shared_ptr<Allocator>& allocator) {
return std::shared_ptr<Allocator>(new CUDAGraphAllocator(allocator));
@@ -973,7 +975,7 @@ AllocatorFacade& AllocatorFacade::Instance() {
AllocatorFacadePrivate* AllocatorFacade::GetPrivate() const {
#ifdef PADDLE_WITH_CUDA
if (UNLIKELY(IsCUDAGraphCapturing())) {
- auto id = platform::CUDAGraph::CapturingID();
+ auto id = platform::CUDAGraph::CapturingPoolID();
auto iter = cuda_graph_map_.find(id);
PADDLE_ENFORCE_NE(
iter, cuda_graph_map_.end(),
@@ -1116,7 +1118,7 @@ void AllocatorFacade::SetDefaultStream(const platform::CUDAPlace& place,
}
#ifdef PADDLE_WITH_CUDA
- void AllocatorFacade::PrepareMemoryPoolForCUDAGraph(CUDAGraphID id) {
+ void AllocatorFacade::PrepareMemoryPoolForCUDAGraph(int64_t id) {
PADDLE_ENFORCE_EQ(GetAllocatorStrategy(), AllocatorStrategy::kAutoGrowth,
platform::errors::InvalidArgument(
"CUDA Graph is only supported when the "
@@ -1124,23 +1126,32 @@ void AllocatorFacade::PrepareMemoryPoolForCUDAGraph(CUDAGraphID id) {
"FLAGS_allocator_strategy=\"%s\"",
FLAGS_allocator_strategy));
auto& allocator = cuda_graph_map_[id];
- PADDLE_ENFORCE_EQ(
-     allocator.get(), nullptr,
-     platform::errors::InvalidArgument(
-         "The memory pool of the CUDA Graph with ID %d have been prepared.",
-         id));
- allocator.reset(new AllocatorFacadePrivate(/*allow_free_idle_chunk=*/false));
- VLOG(10) << "Prepare memory pool for CUDA Graph with ID " << id;
+ auto& ref_cnt = cuda_graph_ref_cnt_[id];
+ if (allocator.get() == nullptr) {
+   allocator.reset(
+       new AllocatorFacadePrivate(/*allow_free_idle_chunk=*/false));
+   VLOG(10) << "Create memory pool for CUDA Graph with memory ID " << id;
+ } else {
+   VLOG(10) << "Use created memory pool for CUDA Graph with memory ID " << id;
+ }
+ ++ref_cnt;
}
- void AllocatorFacade::RemoveMemoryPoolOfCUDAGraph(CUDAGraphID id) {
- auto iter = cuda_graph_map_.find(id);
- PADDLE_ENFORCE_NE(iter, cuda_graph_map_.end(),
+ void AllocatorFacade::RemoveMemoryPoolOfCUDAGraph(int64_t id) {
+ auto ref_cnt_iter = cuda_graph_ref_cnt_.find(id);
+ PADDLE_ENFORCE_NE(ref_cnt_iter, cuda_graph_ref_cnt_.end(),
platform::errors::InvalidArgument(
- "Cannot find CUDA Graph with ID = %d", id));
- cuda_graph_map_.erase(iter);
- VLOG(10) << "Remove memory pool of CUDA Graph with ID " << id;
+ "Cannot find CUDA Graph with memory ID = %d", id));
+ auto& ref_cnt = ref_cnt_iter->second;
+ --ref_cnt;
+ if (ref_cnt == 0) {
+   cuda_graph_map_.erase(id);
+   cuda_graph_ref_cnt_.erase(ref_cnt_iter);
+   VLOG(10) << "Remove memory pool of CUDA Graph with memory ID " << id;
+ } else {
+   VLOG(10) << "Decrease memory pool ID " << id << " reference count to be "
+            << ref_cnt;
+ }
}
#endif
#endif
......
@@ -89,8 +89,8 @@ class AllocatorFacade {
#endif
#ifdef PADDLE_WITH_CUDA
- void PrepareMemoryPoolForCUDAGraph(CUDAGraphID id);
- void RemoveMemoryPoolOfCUDAGraph(CUDAGraphID id);
+ void PrepareMemoryPoolForCUDAGraph(int64_t id);
+ void RemoveMemoryPoolOfCUDAGraph(int64_t id);
#endif
// TODO(yy): Allocate a Copy-On-Write allocation?
@@ -98,8 +98,9 @@
AllocatorFacade();
AllocatorFacadePrivate* m_;
#ifdef PADDLE_WITH_CUDA
- std::unordered_map<CUDAGraphID, std::unique_ptr<AllocatorFacadePrivate>>
+ std::unordered_map<int64_t, std::unique_ptr<AllocatorFacadePrivate>>
    cuda_graph_map_;
+ std::unordered_map<int64_t, int64_t> cuda_graph_ref_cnt_;
#endif
};
......
@@ -107,6 +107,7 @@ register_operators(EXCLUDES py_layer_op py_func_op warpctc_op dgc_op load_combin
recurrent_op save_combine_op sparse_attention_op sync_batch_norm_op spectral_op ${OP_MKL_DEPS} DEPS ${OP_HEADER_DEPS})
op_library(run_program_op SRCS run_program_op.cc run_program_op.cu.cc DEPS executor_cache ${OP_HEADER_DEPS})
target_link_libraries(run_program_op cuda_graph_with_memory_pool)
op_library(quantize_linear_op DEPS cast_kernel)
op_library(save_combine_op DEPS string_array)
op_library(load_combine_op DEPS string_array)
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/framework/tensor.h"
#ifdef PADDLE_WITH_CUDA
#include "paddle/fluid/platform/cuda_graph_with_memory_pool.h"
#endif
namespace paddle {
namespace operators {
#ifdef PADDLE_WITH_CUDA
class CUDAGraphWithInOuts {
public:
template <typename Callable>
CUDAGraphWithInOuts(Callable &&callable, platform::CUDAPlace place,
const std::vector<const framework::Tensor *> &in_ptrs,
cudaStreamCaptureMode mode, int64_t pool_id) {
in_indices_.resize(in_ptrs.size());
ins_.reserve(in_ptrs.size());
int64_t valid_in_idx = 0;
for (size_t i = 0; i < in_ptrs.size(); ++i) {
if (in_ptrs[i] == nullptr) {
in_indices_[i] = -1;
} else {
in_indices_[i] = (valid_in_idx++);
ins_.push_back(*in_ptrs[i]);
}
}
platform::BeginCUDAGraphCapture(place, mode, pool_id);
auto out_ptrs = callable(in_ptrs);
graph_ = platform::EndCUDAGraphCapture();
graph_->Replay();
out_indices_.resize(out_ptrs.size());
outs_.reserve(out_ptrs.size());
int64_t valid_out_idx = 0;
for (size_t i = 0; i < out_ptrs.size(); ++i) {
if (out_ptrs[i] == nullptr) {
out_indices_[i] = -1;
} else {
out_indices_[i] = (valid_out_idx++);
outs_.push_back(*out_ptrs[i]);
}
}
}
void Run(const std::vector<const framework::Tensor *> &ins) {
PADDLE_ENFORCE_EQ(
ins.size(), in_indices_.size(),
phi::errors::InvalidArgument("The input number does not match."));
for (size_t i = 0; i < in_indices_.size(); ++i) {
if (in_indices_[i] >= 0) {
auto *dst = &ins_[in_indices_[i]];
framework::TensorCopy(*ins[i], dst->place(), dst);
}
}
graph_->Replay();
}
std::vector<framework::Tensor *> GetOutputs() {
std::vector<framework::Tensor *> outs(out_indices_.size());
for (size_t i = 0; i < out_indices_.size(); ++i) {
if (out_indices_[i] >= 0) {
outs[i] = &outs_[out_indices_[i]];
}
}
return outs;
}
int64_t PoolID() const { return graph_->PoolID(); }
private:
std::unique_ptr<platform::CUDAGraph> graph_;
std::vector<framework::Tensor> ins_;
std::vector<framework::Tensor> outs_;
std::vector<int64_t> in_indices_;
std::vector<int64_t> out_indices_;
};
template <typename Callable>
static std::unique_ptr<CUDAGraphWithInOuts> CaptureCUDAGraph(
Callable &&callable, const framework::ExecutionContext &ctx,
const std::vector<std::string> &input_names,
const std::vector<std::string> &output_names, cudaStreamCaptureMode mode,
int64_t pool_id) {
std::vector<const framework::Tensor *> inputs;
for (const auto &name : input_names) {
auto input_tensors = ctx.MultiInput<framework::Tensor>(name);
inputs.insert(inputs.end(), input_tensors.begin(), input_tensors.end());
}
auto func = [&](const std::vector<const framework::Tensor *> &inputs) {
callable(ctx);
std::vector<framework::Tensor *> outputs;
for (const auto &name : output_names) {
auto output_tensors = ctx.MultiOutput<framework::Tensor>(name);
outputs.insert(outputs.end(), output_tensors.begin(),
output_tensors.end());
}
return outputs;
};
return std::make_unique<CUDAGraphWithInOuts>(func, ctx.GetPlace(), inputs,
mode, pool_id);
}
static void ExecuteCUDAGraph(const framework::ExecutionContext &ctx,
const std::vector<std::string> &input_names,
const std::vector<std::string> &output_names,
CUDAGraphWithInOuts *graph) {
std::vector<const framework::Tensor *> inputs;
for (const auto &name : input_names) {
auto input_tensors = ctx.MultiInput<framework::Tensor>(name);
inputs.insert(inputs.end(), input_tensors.begin(), input_tensors.end());
}
graph->Run(inputs);
auto outputs = graph->GetOutputs();
size_t idx = 0;
for (const auto &name : output_names) {
auto output_tensors = ctx.MultiOutput<framework::Tensor>(name);
for (auto *out_t : output_tensors) {
if (outputs[idx] != nullptr) {
*out_t = *outputs[idx];
} else {
PADDLE_ENFORCE_EQ(
out_t, nullptr,
phi::errors::InvalidArgument(
"The %d-th output variable should be nullptr.", idx));
}
++idx;
}
}
}
#else
class CUDAGraphWithInOuts {};
#endif
} // namespace operators
} // namespace paddle
@@ -34,6 +34,7 @@ limitations under the License. */
#include "paddle/fluid/operators/dropout_impl_util.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_impl.cu.h"
#include "paddle/fluid/platform/aligned_vector.h"
#include "paddle/fluid/platform/cuda_graph_with_memory_pool.h"
#include "paddle/phi/backends/gpu/gpu_launch_config.h" #include "paddle/phi/backends/gpu/gpu_launch_config.h"
#include "paddle/phi/kernels/funcs/distribution_helper.h" #include "paddle/phi/kernels/funcs/distribution_helper.h"
#include "paddle/phi/kernels/funcs/functors.h" #include "paddle/phi/kernels/funcs/functors.h"
@@ -195,9 +196,11 @@ void DropoutFwGPUKernelDriver(const phi::GPUContext& dev_ctx, bool is_test,
size_t main_offset =
size / (block_size * kVecSize) * (block_size * kVecSize);
- VectorizedRandomGenerator<T, uint8_t><<<grid_size, block_size, 0, stream>>>(
-     size, seed_data, dropout_prob, x_data, mask_data, y_data,
-     upscale_in_train, increment, main_offset);
+ PD_RECORD_CUDA_GRAPH_RANDOM_KERNEL(
+     !is_fix_seed, (VectorizedRandomGenerator<T, uint8_t>), grid_size,
+     block_size, 0, stream, offset, KERNEL_PARAMS.As<uint64_t>(1),
+     KERNEL_PARAMS.As<uint64_t>(7), size, seed_data, dropout_prob, x_data,
+     mask_data, y_data, upscale_in_train, increment, main_offset);
} else { } else {
if (upscale_in_train) { if (upscale_in_train) {
// todo: can y share with data with x directly? // todo: can y share with data with x directly?
......
@@ -90,6 +90,8 @@ class RunProgramOpMaker : public framework::OpProtoAndCheckerMaker {
"computes double grad.")
.AsDuplicable()
.AsDispensable();
AddOutput("CUDAGraph", "The output CUDA Graph when use_cuda_graph=True.")
.AsDispensable();
AddAttr<BlockDesc*>("global_block", AddAttr<BlockDesc*>("global_block",
"(BlockDesc *)" "(BlockDesc *)"
"The global block of executed program desc."); "The global block of executed program desc.");
...@@ -107,6 +109,13 @@ class RunProgramOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -107,6 +109,13 @@ class RunProgramOpMaker : public framework::OpProtoAndCheckerMaker {
"program_id", "program_id",
"(int64_t)" "(int64_t)"
"The unique hash id used as cache key for ExecutorInfoCache."); "The unique hash id used as cache key for ExecutorInfoCache.");
AddAttr<std::string>("cuda_graph_capture_mode",
"(str, default '') The CUDA Graph capture mode. "
"Default '' means no CUDA Graph capturing.")
.SetDefault("");
AddAttr<int64_t>("cuda_graph_pool_id",
"(int64_t, default 0) The CUDA Graph memory pool ID.")
.SetDefault(0);
AddComment(R"DOC( AddComment(R"DOC(
RunProgram operator. RunProgram operator.
...@@ -191,6 +200,9 @@ class RunProgramGradOpMaker : public framework::SingleGradOpMaker<T> { ...@@ -191,6 +200,9 @@ class RunProgramGradOpMaker : public framework::SingleGradOpMaker<T> {
grad_op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out")); grad_op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
grad_op->SetInput("OutScope", this->Output("OutScope")); grad_op->SetInput("OutScope", this->Output("OutScope"));
grad_op->SetInput("DOut", this->Output("DOut")); grad_op->SetInput("DOut", this->Output("DOut"));
if (this->HasOutput("CUDAGraph")) {
grad_op->SetInput("CUDAGraph", this->Output("CUDAGraph"));
}
grad_op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
auto block_desc =
......
@@ -34,6 +34,9 @@ limitations under the License. */
#ifdef PADDLE_WITH_MKLDNN
#include "paddle/fluid/platform/mkldnn_helper.h"
#endif
#ifdef PADDLE_WITH_CUDA
#include "paddle/fluid/operators/cuda_graph_with_in_out.h"
#endif
DECLARE_bool(use_mkldnn);
@@ -167,13 +170,84 @@ static void ShareVarsFromScope(const std::vector<Variable *> &vars,
}
}
#ifdef PADDLE_WITH_CUDA
static cudaStreamCaptureMode StringToCUDAGraphCaptureMode(
const std::string &mode) {
if (mode == "global") {
return cudaStreamCaptureModeGlobal;
} else if (mode == "thread_local") {
return cudaStreamCaptureModeThreadLocal;
} else if (mode == "relaxed") {
return cudaStreamCaptureModeRelaxed;
} else {
PADDLE_THROW(phi::errors::InvalidArgument(
"Unsupported CUDA Graph capture mode %s", mode));
}
}
#endif
} // namespace details
template <typename DeviceContext, typename T>
class RunProgramOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &ctx) const override {
const auto &capture_mode = ctx.Attr<std::string>("cuda_graph_capture_mode");
auto is_test = ctx.Attr<bool>("is_test");
if (capture_mode.empty()) {
ComputeImpl(ctx, is_test, false);
return;
}
#ifdef PADDLE_WITH_CUDA
auto mode = details::StringToCUDAGraphCaptureMode(capture_mode);
PADDLE_ENFORCE_EQ(
platform::is_gpu_place(ctx.GetPlace()), true,
phi::errors::InvalidArgument("The cuda_graph_capture_mode is only "
"valid when using NVIDIA GPU."));
auto *graph_var = ctx.OutputVar("CUDAGraph");
PADDLE_ENFORCE_NOT_NULL(
graph_var,
phi::errors::InvalidArgument("Output(CUDAGraph) must exist when "
"cuda_graph_capture_mode is valid."));
using GraphVecType = std::vector<std::unique_ptr<CUDAGraphWithInOuts>>;
auto &inner_graphs = *(graph_var->GetMutable<GraphVecType>());
inner_graphs.resize(std::max<size_t>(3, inner_graphs.size()));
size_t graph_idx = is_test ? 0 : 1;
if (inner_graphs[graph_idx].get() == nullptr) {
int64_t pool_id;
if (inner_graphs[1 - graph_idx].get() != nullptr) {
pool_id = inner_graphs[1 - graph_idx]->PoolID();
} else {
pool_id = ctx.Attr<int64_t>("cuda_graph_pool_id");
}
framework::PEAndGraphPair pe_and_graph;
auto callable = [this, is_test, &pe_and_graph](
const framework::ExecutionContext &exe_ctx) {
pe_and_graph = ComputeImpl(exe_ctx, is_test, true);
};
inner_graphs[graph_idx] = CaptureCUDAGraph(
callable, ctx, {"X"}, {"Out", "DOut"}, mode, pool_id);
VLOG(10) << "Capture Forward CUDA Graph";
} else {
VLOG(10) << "Run Forward CUDA Graph directly";
ExecuteCUDAGraph(ctx, {"X"}, {"Out", "DOut"},
inner_graphs[graph_idx].get());
}
#else
PADDLE_THROW(
phi::errors::InvalidArgument("The cuda_graph_capture_mode is only "
"valid when using NVIDIA GPU."));
#endif
}
private:
framework::PEAndGraphPair ComputeImpl(const framework::ExecutionContext &ctx,
bool is_test,
bool use_cuda_graph) const {
VLOG(2) << "RunProgramOpKernel Compute"; VLOG(2) << "RunProgramOpKernel Compute";
framework::PEAndGraphPair pe_and_graph;
// Step 1. prepare inputs, outputs, attrs
auto &input_vars = ctx.MultiInputVar("X");
auto &param_vars = ctx.MultiInputVar("Params");
@@ -192,7 +266,6 @@ class RunProgramOpKernel : public framework::OpKernel<T> {
auto start_op_index = ctx.Attr<int64_t>("start_op_index");
auto end_op_index = ctx.Attr<int64_t>("end_op_index");
- auto is_test = ctx.Attr<bool>("is_test");
auto program_id = ctx.Attr<int64_t>("program_id");
// NOTE(chenweihang): In order not to add new variable type, use vector
@@ -223,15 +296,29 @@ class RunProgramOpKernel : public framework::OpKernel<T> {
if (end_op_index > start_op_index) {
auto *program = global_block->Program();
- auto cache_info = framework::GetExecutorInfoFromCache(
-     *program, ctx.GetPlace(), start_op_index, end_op_index,
-     /*is_grad=*/false, program_id, &scope);
- auto &parallel_executor = cache_info.first;
+ bool is_new_created;
+ if (use_cuda_graph) {
+   pe_and_graph = framework::CreateFixOrderExecutorInfo(
+       *program, ctx.GetPlace(), start_op_index, end_op_index, &scope);
+   is_new_created = true;
+ } else {
+   auto cache_info = framework::GetExecutorInfoFromCache(
+       *program, ctx.GetPlace(), start_op_index, end_op_index,
+       /*is_grad=*/false, program_id, &scope);
+   pe_and_graph.first = cache_info.first;
+   is_new_created = cache_info.second;
+ }
+ auto &parallel_executor = pe_and_graph.first;
// all out_vars are skip_eager_var
+ std::vector<std::string> tmp_vars;
auto &skip_eager_delete_vars =
-     framework::ExecutorInfoCache::Instance().SkipEagerDeleteVars(
-         program_id, false);
- if (cache_info.second /*is_new_created*/) {
+     use_cuda_graph
+         ? tmp_vars
+         : framework::ExecutorInfoCache::Instance().SkipEagerDeleteVars(
+               program_id, false);
+ if (is_new_created) {
parallel_executor->SkipMemoryReuse(/*scope_idx=*/0, input_var_names);
skip_eager_delete_vars.insert(skip_eager_delete_vars.end(),
output_var_names.begin(),
@@ -263,6 +350,7 @@ class RunProgramOpKernel : public framework::OpKernel<T> {
#ifdef PADDLE_WITH_MKLDNN
if (FLAGS_use_mkldnn) platform::DontClearMKLDNNCache(ctx.GetPlace());
#endif
return pe_and_graph;
}
};
@@ -270,14 +358,68 @@ template <typename DeviceContext, typename T>
class RunProgramGradOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &ctx) const override {
const auto &capture_mode = ctx.Attr<std::string>("cuda_graph_capture_mode");
if (capture_mode.empty()) {
ComputeImpl(ctx, false);
return;
}
#ifdef PADDLE_WITH_CUDA
auto mode = details::StringToCUDAGraphCaptureMode(capture_mode);
PADDLE_ENFORCE_EQ(
platform::is_gpu_place(ctx.GetPlace()), true,
phi::errors::InvalidArgument("The cuda_graph_capture_mode is only "
"valid when using NVIDIA GPU."));
auto *graph_var =
const_cast<framework::Variable *>(ctx.InputVar("CUDAGraph"));
PADDLE_ENFORCE_NOT_NULL(
graph_var,
phi::errors::InvalidArgument("Output(CUDAGraph) must exist when "
"cuda_graph_capture_mode is valid."));
auto &inner_graphs = *(
graph_var
->GetMutable<std::vector<std::unique_ptr<CUDAGraphWithInOuts>>>());
const size_t graph_idx = 2;
if (inner_graphs[graph_idx].get() == nullptr) {
framework::PEAndGraphPair pe_and_graph;
auto callable =
[this, &pe_and_graph](const framework::ExecutionContext &exe_ctx) {
pe_and_graph = ComputeImpl(exe_ctx, true);
};
int64_t pool_id = inner_graphs[0].get() != nullptr
? inner_graphs[0]->PoolID()
: inner_graphs[1]->PoolID();
inner_graphs[graph_idx] =
CaptureCUDAGraph(callable, ctx, {framework::GradVarName("Out")},
{framework::GradVarName("X")}, mode, pool_id);
VLOG(10) << "Capture Backward CUDA Graph";
} else {
ExecuteCUDAGraph(ctx, {framework::GradVarName("Out")},
{framework::GradVarName("X")},
inner_graphs[graph_idx].get());
VLOG(10) << "Run Backward CUDA Graph directly";
}
#else
PADDLE_THROW(
phi::errors::InvalidArgument("The cuda_graph_capture_mode is only "
"valid when using NVIDIA GPU."));
#endif
}
private:
framework::PEAndGraphPair ComputeImpl(const framework::ExecutionContext &ctx,
bool use_cuda_graph) const {
VLOG(2) << "RunProgramGradOpKernel Compute"; VLOG(2) << "RunProgramGradOpKernel Compute";
framework::PEAndGraphPair pe_and_graph;
// Step 1. prepare inputs and outputs
auto &output_grad_vars = ctx.MultiInputVar(framework::GradVarName("Out"));
auto input_grad_vars = ctx.MultiOutputVar(framework::GradVarName("X"));
auto param_grad_vars = ctx.MultiOutputVar(framework::GradVarName("Params"));
// if all output vars are set to stop_gradient, grad op no need to executed
- if (input_grad_vars.empty() && param_grad_vars.empty()) return;
+ if (input_grad_vars.empty() && param_grad_vars.empty()) {
+   return pe_and_graph;
+ }
auto output_grad_var_names = ctx.InputNames(framework::GradVarName("Out"));
// NOTE: after PR22939 [Add double grad] merged, the grad op maker's
@@ -321,15 +463,27 @@ class RunProgramGradOpKernel : public framework::OpKernel<T> {
if (end_op_index > start_op_index) {
// Step 2. prepare executor and scope
auto *program = global_block->Program();
- auto cache_info = framework::GetExecutorInfoFromCache(
-     *program, ctx.GetPlace(), start_op_index, end_op_index,
-     /*is_grad*/ true, program_id, &scope);
- auto &parallel_executor = cache_info.first;
+ bool is_new_created;
+ if (use_cuda_graph) {
+   pe_and_graph = framework::CreateFixOrderExecutorInfo(
+       *program, ctx.GetPlace(), start_op_index, end_op_index, &scope);
+   is_new_created = true;
+ } else {
+   auto cache_info = framework::GetExecutorInfoFromCache(
+       *program, ctx.GetPlace(), start_op_index, end_op_index,
+       /*is_grad*/ true, program_id, &scope);
+   pe_and_graph.first = cache_info.first;
+   is_new_created = cache_info.second;
+ }
+ auto &parallel_executor = pe_and_graph.first;
+ std::vector<std::string> tmp_vars;
auto &skip_eager_delete_vars =
-     framework::ExecutorInfoCache::Instance().SkipEagerDeleteVars(
-         program_id, true);
- if (cache_info.second /*is_new_created*/) {
+     use_cuda_graph
+         ? tmp_vars
+         : framework::ExecutorInfoCache::Instance().SkipEagerDeleteVars(
+               program_id, true);
+ if (is_new_created) {
parallel_executor->SkipMemoryReuse(/*scope_idx=*/0,
output_grad_var_names);
@@ -360,6 +514,7 @@ class RunProgramGradOpKernel : public framework::OpKernel<T> {
global_inner_scope->DeleteScope(&scope);
VLOG(2) << "The number of sub scopes after backward: "
<< global_inner_scope->kids().size();
return pe_and_graph;
}
};
......
@@ -16,23 +16,33 @@
#include "paddle/fluid/memory/allocation/allocator_facade.h"
#include "paddle/fluid/platform/device_context.h"
DECLARE_bool(use_stream_safe_cuda_allocator);
namespace paddle {
namespace platform {
#ifdef PADDLE_WITH_CUDA
void BeginCUDAGraphCapture(platform::CUDAPlace place,
- cudaStreamCaptureMode mode) {
+ cudaStreamCaptureMode mode, int64_t pool_id) {
auto *dev_ctx = platform::DeviceContextPool::Instance().GetByPlace(place);
dev_ctx->cudnn_workspace_handle().ResetWorkspace();
auto stream = dev_ctx->stream();
CUDAGraph::BeginCapture(place, stream, mode);
- auto id = CUDAGraph::CapturingID();
+ auto old_value = FLAGS_use_stream_safe_cuda_allocator;
+ if (old_value) {
+   FLAGS_use_stream_safe_cuda_allocator = false;
+ }
+ pool_id = CUDAGraph::SetMemoryPoolID(pool_id);
memory::allocation::AllocatorFacade::Instance().PrepareMemoryPoolForCUDAGraph(
- id);
- AddResetCallbackIfCapturingCUDAGraph([id] {
+ pool_id);
+ if (old_value) {
+   FLAGS_use_stream_safe_cuda_allocator = true;
+ }
+ AddResetCallbackIfCapturingCUDAGraph([pool_id] {
memory::allocation::AllocatorFacade::Instance().RemoveMemoryPoolOfCUDAGraph(
- id);
+ pool_id);
});
}
......
@@ -23,10 +23,53 @@
namespace paddle {
namespace platform {
#ifdef PADDLE_WITH_CUDA
#define PD_RECORD_CUDA_GRAPH_RANDOM_KERNEL(__cond, __kernel_func, __grid, \
__block, __sm_size, __stream, \
__seed_inc, __seed_expr, \
__offset_expr, ...) \
do { \
if (::paddle::platform::CUDAGraph::IsThisThreadCapturing() && (__cond)) { \
using __Helper = \
::paddle::platform::IsSameKernelHelper<decltype(&__kernel_func), \
&__kernel_func>; \
auto *dev_ctx = \
::paddle::platform::DeviceContextPool::Instance().GetByPlace( \
::paddle::platform::CUDAGraph::CapturingPlace()); \
auto __set_seed_func = [=]( \
::paddle::platform::CUDAKernelParams *__params, \
bool __check_only) -> bool { \
if (__check_only) { \
return __params->func() == &__kernel_func && \
__Helper::Compare(*__params, __VA_ARGS__); \
} \
auto &KERNEL_PARAMS = *__params; \
uint64_t __seed, __offset; \
::paddle::operators::GetSeedDataAndIncrement( \
*dev_ctx, nullptr, false, 0, __seed_inc, &__seed, &__offset); \
__seed_expr = static_cast<decltype(__seed_expr)>(__seed); \
__offset_expr = static_cast<decltype(__offset_expr)>(__offset); \
return true; \
}; \
::paddle::platform::CUDAGraph::RecordRandomKernelInfo(__set_seed_func); \
} \
__kernel_func<<<__grid, __block, __sm_size, __stream>>>(__VA_ARGS__); \
} while (0)
#else
#define PD_RECORD_CUDA_GRAPH_RANDOM_KERNEL(__cond, __kernel_func, __grid, \
__block, __sm_size, __stream, \
__seed_inc, __seed_expr, \
__offset_expr, ...) \
do { \
__kernel_func<<<__grid, __block, __sm_size, __stream>>>(__VA_ARGS__); \
} while (0)
#endif
// NOTE: These APIs are not thread-safe.
#ifdef PADDLE_WITH_CUDA
void BeginCUDAGraphCapture(platform::CUDAPlace place,
- cudaStreamCaptureMode mode);
+ cudaStreamCaptureMode mode,
+ int64_t pool_id = CUDAGraph::kInvalidPoolID);
std::unique_ptr<CUDAGraph> EndCUDAGraphCapture();
#endif
......
@@ -13,6 +13,9 @@
// limitations under the License.
#include "paddle/fluid/platform/device/gpu/cuda/cuda_graph.h"
#include <queue>
#include <unordered_map>
#include <unordered_set>
namespace paddle {
namespace platform {
@@ -20,6 +23,69 @@ namespace platform {
std::unique_ptr<CUDAGraph> CUDAGraph::capturing_graph_{nullptr};
paddle::optional<std::thread::id> CUDAGraph::capturing_thread_id_{paddle::none};
static std::vector<cudaGraphNode_t> ToposortCUDAGraph(cudaGraph_t graph) {
size_t num_nodes;
PADDLE_ENFORCE_GPU_SUCCESS(cudaGraphGetNodes(graph, nullptr, &num_nodes));
std::vector<cudaGraphNode_t> nodes(num_nodes);
PADDLE_ENFORCE_GPU_SUCCESS(
cudaGraphGetNodes(graph, nodes.data(), &num_nodes));
size_t num_edges;
PADDLE_ENFORCE_GPU_SUCCESS(
cudaGraphGetEdges(graph, nullptr, nullptr, &num_edges));
std::vector<cudaGraphNode_t> from(num_edges), to(num_edges);
PADDLE_ENFORCE_GPU_SUCCESS(
cudaGraphGetEdges(graph, from.data(), to.data(), &num_edges));
std::unordered_map<cudaGraphNode_t, std::unordered_set<cudaGraphNode_t>>
in_edges, out_edges;
for (auto node : nodes) {
in_edges[node];
out_edges[node];
}
for (size_t i = 0; i < num_edges; ++i) {
in_edges[to[i]].insert(from[i]);
out_edges[from[i]].insert(to[i]);
}
std::queue<cudaGraphNode_t> q;
for (const auto &pair : in_edges) {
if (pair.second.empty()) {
q.push(pair.first);
}
}
nodes.clear();
while (!q.empty()) {
auto cur = q.front();
q.pop();
nodes.push_back(cur);
for (auto out_node : out_edges.at(cur)) {
auto &in_nodes = in_edges.at(out_node);
in_nodes.erase(cur);
if (in_nodes.empty()) {
q.push(out_node);
}
}
}
PADDLE_ENFORCE_EQ(
nodes.size(), num_nodes,
phi::errors::InvalidArgument("Toposort error, this may be a bug."));
return nodes;
}
CUDAGraphID CUDAGraph::UniqueID() {
static std::atomic<CUDAGraphID> id;
return id.fetch_add(1);
}
int64_t CUDAGraph::UniqueMemoryPoolID() {
static std::atomic<int64_t> id(CUDAGraph::kDefaultPoolID + 1);
return id.fetch_add(1);
}
void CUDAGraph::Reset() {
if (is_reset_) return;
#if CUDA_VERSION >= 10010
@@ -46,9 +112,16 @@ void CUDAGraph::Replay() {
PADDLE_ENFORCE_EQ(is_reset_, false,
errors::PermissionDenied(
"Cannot replay the CUDA Graph after reset is called."));
- for (auto exec_graph : exec_graphs_) {
-   PADDLE_ENFORCE_GPU_SUCCESS(cudaGraphLaunch(exec_graph, stream_));
+ size_t n = exec_graphs_.size();
+ for (size_t i = 0; i < n; ++i) {
+   if (!is_first_run_) {
+     for (auto &hook : pre_hooks_[i]) {
+       hook(exec_graphs_[i]);
+     }
+   }
+   PADDLE_ENFORCE_GPU_SUCCESS(cudaGraphLaunch(exec_graphs_[i], stream_));
  }
+ is_first_run_ = false;
#endif
}
@@ -72,7 +145,8 @@ void CUDAGraph::BeginSegmentCapture() {
platform::errors::PermissionDenied(
"CUDA Graph should not be invalidated."));
VLOG(10) << "Begin to capture CUDA Graph with ID " << capturing_graph_->id_
-          << ", segment id " << capturing_graph_->graphs_.size();
+          << ", segment id " << capturing_graph_->graphs_.size()
+          << ", memory pool id " << capturing_graph_->pool_id_;
#endif
}
@@ -112,15 +186,57 @@ void CUDAGraph::EndSegmentCapture() {
if (num_nodes == 0) {
PADDLE_ENFORCE_GPU_SUCCESS(cudaGraphDestroy(graph));
VLOG(10) << "Skip empty CUDA Graph with ID " << capturing_graph_->id_
-          << ", segment id " << capturing_graph_->graphs_.size();
+          << ", segment id " << capturing_graph_->graphs_.size()
+          << ", memory pool id " << capturing_graph_->pool_id_;
return;
}
auto sorted_nodes = ToposortCUDAGraph(graph);
capturing_graph_->pre_hooks_.emplace_back();
std::unordered_set<cudaGraphNode_t> visited;
VLOG(10) << "SetSeedFunc number : "
<< capturing_graph_->set_seed_funcs_.size();
for (const auto &set_seed_func : capturing_graph_->set_seed_funcs_) {
bool found = false;
for (auto node : sorted_nodes) {
if (visited.count(node) > 0) continue;
cudaGraphNodeType type;
PADDLE_ENFORCE_GPU_SUCCESS(cudaGraphNodeGetType(node, &type));
if (type == cudaGraphNodeTypeKernel) {
cudaKernelNodeParams params;
auto err = cudaGraphKernelNodeGetParams(node, &params);
if (err == cudaErrorInvalidDeviceFunction) {
continue;
} else {
PADDLE_ENFORCE_GPU_SUCCESS(err);
}
CUDAKernelParams kernel_params(&params);
if (set_seed_func(&kernel_params, true)) {
capturing_graph_->pre_hooks_.back().push_back(
[set_seed_func, node, params](cudaGraphExec_t exec_graph) {
CUDAKernelParams kernel_params(&params);
set_seed_func(&kernel_params, false);
PADDLE_ENFORCE_GPU_SUCCESS(cudaGraphExecKernelNodeSetParams(
exec_graph, node, &params));
});
visited.insert(node);
found = true;
break;
}
}
}
PADDLE_ENFORCE_EQ(found, true,
phi::errors::InvalidArgument(
"Cannot find the corresponding random CUDA kernel."));
}
capturing_graph_->set_seed_funcs_.clear();
cudaGraphExec_t exec_graph;
PADDLE_ENFORCE_GPU_SUCCESS(
cudaGraphInstantiate(&exec_graph, graph, nullptr, nullptr, 0));
VLOG(10) << "End to capture CUDA Graph with ID " << capturing_graph_->id_
-          << ", segment id " << capturing_graph_->graphs_.size();
+          << ", segment id " << capturing_graph_->graphs_.size()
+          << ", memory pool id " << capturing_graph_->pool_id_;
capturing_graph_->graphs_.emplace_back(graph);
capturing_graph_->exec_graphs_.emplace_back(exec_graph);
#endif
......
@@ -32,6 +32,70 @@
namespace paddle {
namespace platform {
template <typename T>
static bool IsBitwiseEqual(const T &x, const T &y) {
return std::memcmp(&x, &y, sizeof(T)) == 0;
}
class CUDAKernelParams {
public:
explicit CUDAKernelParams(const cudaKernelNodeParams *params)
: params_(params) {}
const void *func() const { return params_->func; }
template <typename T>
T &As(size_t idx) const {
return *reinterpret_cast<T *>(params_->kernelParams[idx]);
}
private:
const cudaKernelNodeParams *params_;
};
template <typename F, F f>
struct IsSameKernelHelper;
template <typename Return, typename... FuncArgs,
Return (*kernel_fn)(FuncArgs...)>
struct IsSameKernelHelper<Return (*)(FuncArgs...), kernel_fn> {
private:
template <typename TupleT, size_t IDX, bool IsEnd /*=false*/>
struct Impl {
static bool Compare(const CUDAKernelParams &params, const TupleT &args) {
using CompareT = typename std::tuple_element<IDX, TupleT>::type;
if (!IsBitwiseEqual<CompareT>(params.As<CompareT>(IDX),
std::get<IDX>(args))) {
return false;
}
constexpr auto NewIsEnd =
(IDX + 1 == sizeof(std::tuple_size<TupleT>::value));
return Impl<TupleT, IDX + 1, NewIsEnd>::Compare(params, args);
}
};
template <typename TupleT, size_t IDX>
struct Impl<TupleT, IDX, true> {
static bool Compare(const CUDAKernelParams &params, const TupleT &args) {
return true;
}
};
public:
using FuncArgsTuple = decltype(std::make_tuple(std::declval<FuncArgs>()...));
template <typename... Args>
static bool Compare(const CUDAKernelParams &params, Args... args) {
constexpr auto kNumArgs = sizeof...(FuncArgs);
static_assert(kNumArgs == sizeof...(Args), "Argument number not match");
auto args_tuple = std::make_tuple(args...);
using TupleT = typename std::decay<decltype(args_tuple)>::type;
return Impl<TupleT, 0, kNumArgs == 0>::Compare(params, args_tuple);
}
};
#if CUDA_VERSION >= 10010
static void ThrowErrorIfNotSupportCUDAGraph() {}
#else
@@ -61,10 +125,35 @@ class CUDAGraph {
}
public:
static constexpr int64_t kDefaultPoolID = 0;
static constexpr int64_t kInvalidPoolID = -1;
~CUDAGraph() { Reset(); }
CUDAGraphID ID() const { return id_; }
static int64_t SetMemoryPoolID(int64_t pool_id) {
auto &pool_id_ = capturing_graph_->pool_id_;
PADDLE_ENFORCE_EQ(
pool_id_, kInvalidPoolID,
phi::errors::InvalidArgument("Cannot reset memory pool id twice, the "
"former memory pool id is %d.",
pool_id_));
if (pool_id <= kInvalidPoolID) {
pool_id_ = UniqueMemoryPoolID();
} else {
PADDLE_ENFORCE_GE(
pool_id, kDefaultPoolID,
phi::errors::InvalidArgument("Invalid memory pool id %d.", pool_id));
pool_id_ = pool_id;
}
return pool_id_;
}
int64_t PoolID() const { return pool_id_; }
static int64_t CapturingPoolID() { return capturing_graph_->pool_id_; }
void Replay();
void Reset();
@@ -120,12 +209,17 @@ class CUDAGraph {
}
}
- private:
- static CUDAGraphID UniqueID() {
-   static std::atomic<CUDAGraphID> id;
-   return id.fetch_add(1);
+ using SetSeedFunc = std::function<bool(CUDAKernelParams *, bool)>;
+ static void RecordRandomKernelInfo(SetSeedFunc set_seed_func) {
+   std::lock_guard<std::mutex> guard(capturing_graph_->func_mtx_);
+   capturing_graph_->set_seed_funcs_.emplace_back(std::move(set_seed_func));
  }
+ static int64_t UniqueMemoryPoolID();
+ private:
+ static CUDAGraphID UniqueID();
private:
#if CUDA_VERSION >= 10010
std::vector<cudaGraph_t> graphs_;
@@ -135,10 +229,17 @@ class CUDAGraph {
cudaStream_t stream_{nullptr};
platform::CUDAPlace place_;
CUDAGraphID id_;
int64_t pool_id_{kInvalidPoolID};
std::vector<std::function<void()>> callbacks_;
bool is_reset_{false};
std::mutex mtx_;
std::vector<SetSeedFunc> set_seed_funcs_;
std::vector<std::vector<std::function<void(cudaGraphExec_t)>>> pre_hooks_;
std::mutex func_mtx_;
bool is_first_run_{true};
static paddle::optional<std::thread::id> capturing_thread_id_;
static std::unique_ptr<CUDAGraph> capturing_graph_;
};
......
@@ -27,7 +27,8 @@ static PyObject *eager_api_run_program(PyObject *self, PyObject *args,
GetScopePtrListFromArgs("run_program", "OutScope", args, 3, false);
auto DOut = GetTensorPtrListFromArgs("run_program", "DOut", args, 4, true);
framework::AttributeMap attrs;
- ConstructAttrMapFromPyArgs("run_program", args, 5, PyTuple_GET_SIZE(args),
+ // TODO(zengjinle): support CUDA Graph on eager mode
+ ConstructAttrMapFromPyArgs("run_program", args, 6, PyTuple_GET_SIZE(args),
attrs);
tstate = PyEval_SaveThread();
......
@@ -640,10 +640,11 @@ void CastPyArg2AttrBlock(PyObject* obj,
void ConstructAttrMapFromPyArgs(
const std::string& op_type, PyObject* args, ssize_t attr_start,
ssize_t attr_end, paddle::framework::AttributeMap& attrs) {  // NOLINT
- PADDLE_ENFORCE_EQ(
-     (attr_end - attr_start) % 2, 0,
-     platform::errors::InvalidArgument(
-         "The number of arguments for attributes should be even."));
+ PADDLE_ENFORCE_EQ((attr_end - attr_start) % 2, 0,
+                   platform::errors::InvalidArgument(
+                       "The number of arguments for attributes should be even "
+                       "but attr_start = %d, attr_end = %d.",
+                       attr_start, attr_end));
auto attr_type_map = &(OpAttrTypeMap::Instance().Map()[op_type]);
......
@@ -182,7 +182,7 @@ std::map<std::string, std::set<std::string>> op_outs_map = {
{"merged_momentum", {"ParamOut", "VelocityOut", "MasterParamOut"}},
{"sparse_momentum", {"ParamOut", "VelocityOut", "MasterParamOut"}},
{"rnn", {"DropoutState", "Reserve", "Out", "State"}},
- {"run_program", {"DOut"}},
+ {"run_program", {"DOut", "CUDAGraph"}},
{"adam",
{"ParamOut", "Moment1Out", "Moment2Out", "Beta1PowOut", "Beta2PowOut",
"MasterParamOut"}},
@@ -267,7 +267,7 @@ std::map<std::string, std::set<std::string>> op_passing_outs_map = {
{"moving_average_abs_max_scale",
{"Out", "OutScale", "OutAccum", "OutState"}},
{"rnn", {"DropoutState"}},
- {"run_program", {"Out", "DOut", "OutScope"}},
+ {"run_program", {"Out", "DOut", "OutScope", "CUDAGraph"}},
{"clear_float_status", {"FloatStatusOut"}},
{"get_float_status", {"FloatStatusOut"}},
{"assign", {"Out"}},
......
@@ -604,6 +604,8 @@ PYBIND11_MODULE(core_noavx, m) {
place, static_cast<cudaStreamCaptureMode>(mode));
})
.def_static("end_capture", &platform::EndCUDAGraphCapture)
.def_static("gen_new_memory_pool_id",
&platform::CUDAGraph::UniqueMemoryPoolID)
.def("replay", &platform::CUDAGraph::Replay) .def("replay", &platform::CUDAGraph::Replay)
.def("reset", &platform::CUDAGraph::Reset) .def("reset", &platform::CUDAGraph::Reset)
.def("print_to_dot_files", &platform::CUDAGraph::PrintToDotFiles); .def("print_to_dot_files", &platform::CUDAGraph::PrintToDotFiles);
......
@@ -17,15 +17,23 @@ from paddle.fluid.core import is_compiled_with_cuda, is_compiled_with_rocm, CUDA
if is_compiled_with_cuda() and not is_compiled_with_rocm():
from paddle.fluid.core import CUDAGraph as CoreCUDAGraph
def is_cuda_graph_supported():
return True
else:
CoreCUDAGraph = None
def is_cuda_graph_supported():
return False
ALL_MODES = ["global", "thread_local", "relaxed"]
class CUDAGraph:
def __init__(self, place=None, mode="thread_local"):
assert CoreCUDAGraph is not None, "CUDA Graph is only supported on PaddlePaddle compiled with NVIDIA GPU."
- ALL_MODES = ["global", "thread_local", "relaxed"]
self._graph = None
if place is None:
device_id = int(os.environ.get('FLAGS_selected_gpus', 0))
@@ -55,3 +63,25 @@ class CUDAGraph:
if flags is None:
flags = 2047  # only all information. It can be any integer inside [1, 2048)
self._graph.print_to_dot_files(dirname, flags)
def wrap_cuda_graph(function, mode="thread_local", memory_pool="default"):
assert mode in ALL_MODES
from paddle.jit import to_static
from paddle.nn import Layer
new_function = to_static(function)
if isinstance(function, Layer):
mock_func = new_function.forward
else:
mock_func = new_function
mock_func._cuda_graph_capture_mode = mode
if memory_pool == "default":
mock_func._cuda_graph_pool_id = 0
elif memory_pool == "new":
mock_func._cuda_graph_pool_id = CoreCUDAGraph.gen_new_memory_pool_id()
else:
if isinstance(memory_pool, Layer):
mock_func._cuda_graph_pool_id = memory_pool.forward._cuda_graph_pool_id
else:
mock_func._cuda_graph_pool_id = memory_pool._cuda_graph_pool_id
return new_function
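Editorial note on the memory_pool argument above: "default" maps to pool id 0, "new" asks the C++ side for a fresh pool id via CUDAGraph.gen_new_memory_pool_id, and passing an already wrapped Layer (or wrapped function) reuses that object's pool id, so the two captured graphs share one memory pool. A hedged sketch (the layer and sizes are illustrative only):

import paddle.nn as nn
from paddle.device.cuda.graphs import wrap_cuda_graph

# "new" requests a fresh CUDA Graph memory pool; passing layer1 reuses layer1's pool.
layer1 = wrap_cuda_graph(nn.Linear(10, 20), memory_pool="new")
layer2 = wrap_cuda_graph(nn.Linear(10, 20), memory_pool=layer1)

The unit test added at the end of this diff exercises all three forms ("default", "new", and another wrapped layer) and checks that they produce identical gradients.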
@@ -148,6 +148,9 @@ class PartialProgramLayer:
self._origin_main_program = self._verify_program(main_program)
self._tmp_scope_vec = self._create_scope_vec()
self._cuda_graph_vec = self._create_cuda_graph_vec()
self._cuda_graph_capture_mode = ""
self._cuda_graph_pool_id = 0
# Set default mode to train
self.training = True
@@ -339,9 +342,15 @@ class PartialProgramLayer:
def __call__(self, inputs):
in_vars, out_vars = self._prepare(inputs)
- attrs = ('global_block', self.program.desc.block(0), 'start_op_index',
-          0, 'end_op_index', self._get_end_op_index(), 'is_test',
-          not self.training, 'program_id', self.program_id)
+ attrs = [
+     'global_block', self.program.desc.block(0), 'start_op_index', 0,
+     'end_op_index', self._get_end_op_index(), 'is_test',
+     not self.training, 'program_id', self.program_id
+ ]
+ if self._cuda_graph_capture_mode:
+     attrs.extend(
+         ('cuda_graph_capture_mode', self._cuda_graph_capture_mode,
+          'cuda_graph_pool_id', self._cuda_graph_pool_id))
self._cast_fp16_if_pure_fp16(in_vars)
@@ -349,7 +358,7 @@ class PartialProgramLayer:
self._valid_vars(in_vars),
self._valid_vars(self._params),
self._valid_vars(out_vars), self._tmp_scope_vec, self._double_grads,
- *attrs)
+ self._cuda_graph_vec, *attrs)
self.drop_scope_if_no_grad()
restored_nest_out = self._restore_out(out_vars)
return self._remove_no_value(restored_nest_out)
@@ -471,6 +480,12 @@ class PartialProgramLayer:
tmp_scope_vec = [inner_scope]
return tmp_scope_vec
def _create_cuda_graph_vec(self):
var = core.VarBase(core.VarDesc.VarType.FP32, [], "cuda_graph",
core.VarDesc.VarType.RAW, True)
var.stop_gradient = True
return var
def _restore_out(self, out_vars):
"""
Restores same nested outputs by only replacing the Variable with VarBase.
......
@@ -267,6 +267,8 @@ class StaticFunction(object):
self._program_trans = ProgramTranslator()
self._kwargs = kwargs
self._training = True
self._cuda_graph_capture_mode = ""
self._cuda_graph_pool_id = 0
def train(self):
if isinstance(self._class_instance,
@@ -367,6 +369,9 @@ class StaticFunction(object):
else:
partial_program_layer.training = self._training
partial_program_layer._cuda_graph_capture_mode = self._cuda_graph_capture_mode
partial_program_layer._cuda_graph_pool_id = self._cuda_graph_pool_id
# 4. return outputs.
try:
return partial_program_layer(args)
......
@@ -874,7 +874,7 @@ def _run_dygraph(instance, input, program_holder):
_valid_vars(input_vars),
_valid_vars(persistable_vars),
_valid_vars(output_vars), tmp_scope_vec,
- _valid_vars(double_grad_vars), *attrs)
+ _valid_vars(double_grad_vars), None, *attrs)
# NOTE: [ why need set param's gradient type here ]
# if user set sparse gradient mode, the param's gradient
# will be SelectedRows, not LoDTensor. But tracer will just
......
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
import paddle.nn as nn
import unittest
import numpy as np
from paddle.device.cuda.graphs import wrap_cuda_graph, is_cuda_graph_supported
class SimpleModel(nn.Layer):
def __init__(self, in_size, out_size):
super(SimpleModel, self).__init__()
self.linear = nn.Linear(in_size, out_size)
self.dropout_1 = paddle.nn.Dropout(0.1)
self.relu = nn.ReLU()
self.dropout_2 = paddle.nn.Dropout(0.5)
self.gelu = nn.GELU()
def forward(self, x):
x = self.linear(x)
x = self.dropout_1(x)
x = self.relu(x)
x = self.dropout_2(x)
x = self.gelu(x)
return x
class TestSimpleModel(unittest.TestCase):
def setUp(self):
paddle.set_flags({'FLAGS_eager_delete_tensor_gb': 0.0})
def run_base(self, func, use_cuda_graph, memory_pool="default", seed=10):
paddle.seed(seed)
is_layer = isinstance(func, paddle.nn.Layer)
if use_cuda_graph:
func = wrap_cuda_graph(func, memory_pool=memory_pool)
for _ in range(10):
x = paddle.randn([3, 10], dtype='float32')
x.stop_gradient = False
y = x * x + 100
loss = func(y).mean()
loss.backward()
if is_layer:
func.clear_gradients()
return func, x.grad.numpy()
def check(self, func):
if not is_cuda_graph_supported():
return
_, value1 = self.run_base(func, False)
layer, value2 = self.run_base(func, True, "default")
_, value3 = self.run_base(func, True, "new")
_, value4 = self.run_base(func, True, layer)
self.assertTrue(np.array_equal(value1, value2))
self.assertTrue(np.array_equal(value1, value3))
self.assertTrue(np.array_equal(value1, value4))
def test_layer(self):
self.check(SimpleModel(10, 20))
if __name__ == "__main__":
unittest.main()
@@ -104,7 +104,7 @@ class TestRunProgram(unittest.TestCase):
'is_test', False, 'program_id', _hash_with_id(program))
_C_ops.run_program([x_t, y_t], [fake_var], [out_t], [scope],
- [fake_var], *attrs)
+ [fake_var], None, *attrs)
loss = paddle.mean(out_t)
loss.backward()
......
@@ -188,7 +188,7 @@ class RunProgramOpTest(unittest.TestCase):
outputs = self.prepare_dygraph_output()
_C_ops.run_program(inputs['X'], inputs['Params'], outputs['Out'],
- outputs['OutScope'], outputs['DOut'],
+ outputs['OutScope'], outputs['DOut'], None,
*self.attrs)
return outputs['Out']
@@ -202,7 +202,7 @@ class RunProgramOpTest(unittest.TestCase):
outputs = self.prepare_dygraph_output()
_C_ops.run_program(inputs['X'], inputs['Params'], outputs['Out'],
- outputs['OutScope'], outputs['DOut'],
+ outputs['OutScope'], outputs['DOut'], None,
*self.attrs)
for param in input_param_list:
......