未验证 提交 8e5ed04d 编写于 作者: P pangyoki 提交者: GitHub

support CUDA Graph for new executor (#49708)

* new exe supports CUDA Graph

* fix

* fix

* fix

* fix FLAGS_use_stream_safe_cuda_allocator in unittest

* insert output of coalesce_tensor op to skip_gc_var

* fix
上级 76302bdc
......@@ -19,14 +19,27 @@
#include "paddle/fluid/framework/new_executor/garbage_collector/no_event_garbage_collector.h"
DECLARE_bool(fast_eager_deletion_mode);
DECLARE_bool(new_executor_use_cuda_graph);
namespace paddle {
namespace framework {
// Decides whether InterpreterCore should use the fast garbage collector.
//
// Fast GC is selected when the stream-safe CUDA allocator is in use and
// FLAGS_fast_eager_deletion_mode is set; it is additionally forced whenever
// FLAGS_new_executor_use_cuda_graph is set.
//
// Fails the enforce check with an InvalidArgument error when
// FLAGS_new_executor_use_cuda_graph is true but the stream-safe CUDA
// allocator is not in use.
//
// The stale pre-CUDA-Graph `return` that used to open this body has been
// dropped: it returned before the check below and before the CUDA Graph flag
// was taken into account, leaving both unreachable.
bool IsInterpretercoreFastGCEnabled() {
  // CUDA Graph requires fast GC, because the `EventQuery` method used by the
  // event GC cannot be used in a CUDA graph.
  PADDLE_ENFORCE_EQ(memory::allocation::AllocatorFacade::Instance()
                            .IsStreamSafeCUDAAllocatorUsed() == false &&
                        FLAGS_new_executor_use_cuda_graph,
                    false,
                    platform::errors::InvalidArgument(
                        "When FLAGS_new_executor_use_cuda_graph is true, "
                        "IsStreamSafeCUDAAllocatorUsed must be true, but "
                        "got false."));
  return (memory::allocation::AllocatorFacade::Instance()
              .IsStreamSafeCUDAAllocatorUsed() &&
          FLAGS_fast_eager_deletion_mode) ||
         FLAGS_new_executor_use_cuda_graph;
}
InterpreterCoreGarbageCollector::InterpreterCoreGarbageCollector() {
......
......@@ -31,6 +31,7 @@
#ifdef PADDLE_WITH_MKLDNN
#include "paddle/fluid/platform/mkldnn_helper.h"
#endif
#include "paddle/fluid/platform/cuda_graph_with_memory_pool.h"
#include "paddle/phi/backends/device_manager.h"
PADDLE_DEFINE_EXPORTED_bool(
......@@ -50,6 +51,10 @@ PADDLE_DEFINE_EXPORTED_bool(control_flow_use_new_executor,
DECLARE_bool(check_nan_inf);
DECLARE_bool(benchmark);
DECLARE_bool(new_executor_use_cuda_graph);
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
DECLARE_bool(sync_nccl_allreduce);
#endif
constexpr const char* kExceptionCaught = "ExceptionCaught";
constexpr const char* kTaskCompletion = "TaskCompletion";
......@@ -142,6 +147,8 @@ InterpreterCore::InterpreterCore(const platform::Place& place,
}
return lhs_prority > rhs_prority;
};
PrepareForCUDAGraphCapture();
}
InterpreterCore::~InterpreterCore() {
......@@ -161,6 +168,7 @@ interpreter::CostInfo InterpreterCore::DryRun(
const std::vector<std::string>& feed_names,
const std::vector<phi::DenseTensor>& feed_tensors) {
SetDeviceId(place_);
CheckCUDAGraphBeforeRun(feed_names);
Prepare(feed_names, feed_tensors, true);
interpreter::CostInfo cost_info;
......@@ -221,6 +229,7 @@ paddle::framework::FetchList InterpreterCore::Run(
const std::vector<std::string>& feed_names,
const std::vector<phi::DenseTensor>& feed_tensors) {
SetDeviceId(place_);
CheckCUDAGraphBeforeRun(feed_names);
#ifdef PADDLE_WITH_MKLDNN
platform::AttachPointerHashToMKLDNNKey(this, place_);
......@@ -240,7 +249,16 @@ paddle::framework::FetchList InterpreterCore::Run(
// return Fetch Tensors
auto* fetch_var = local_scope_->FindVar(interpreter::kFetchVarName);
if (fetch_var) {
return std::move(*fetch_var->GetMutable<framework::FetchList>());
auto fetch_list = std::move(*fetch_var->GetMutable<framework::FetchList>());
#ifdef PADDLE_WITH_CUDA
if (platform::IsCUDAGraphCapturing()) {
PADDLE_ENFORCE_EQ(fetch_list.empty(),
true,
platform::errors::InvalidArgument(
"Cannot fetch data when using CUDA Graph."));
}
#endif
return fetch_list;
} else {
return {};
}
......@@ -249,6 +267,7 @@ paddle::framework::FetchList InterpreterCore::Run(
paddle::framework::FetchList InterpreterCore::Run(
const std::vector<std::string>& feed_names, bool need_fetch) {
SetDeviceId(place_);
CheckCUDAGraphBeforeRun(feed_names);
#ifdef PADDLE_WITH_MKLDNN
platform::AttachPointerHashToMKLDNNKey(this, place_);
......@@ -290,7 +309,16 @@ paddle::framework::FetchList InterpreterCore::Run(
HasLocalScope() ? local_scope_ : var_scope_.GetMutableScope();
auto* fetch_var = inner_scope->FindVar(interpreter::kFetchVarName);
if (fetch_var && need_fetch) {
return std::move(*fetch_var->GetMutable<framework::FetchList>());
auto fetch_list = std::move(*fetch_var->GetMutable<framework::FetchList>());
#ifdef PADDLE_WITH_CUDA
if (platform::IsCUDAGraphCapturing()) {
PADDLE_ENFORCE_EQ(fetch_list.empty(),
true,
platform::errors::InvalidArgument(
"Cannot fetch data when using CUDA Graph."));
}
#endif
return fetch_list;
} else {
return {};
}
......@@ -504,6 +532,67 @@ void InterpreterCore::BuildInplace() {
}
}
// One-time CUDA Graph preparation.
//
// Does nothing unless FLAGS_new_executor_use_cuda_graph is set. On CUDA
// builds it verifies that no capture is already in progress, that `place_`
// is a GPU place and that FLAGS_sync_nccl_allreduce is off, then registers
// every output of each coalesce_tensor op in `skip_gc_vars`. On non-CUDA
// builds it throws Unimplemented.
void InterpreterCore::PrepareForCUDAGraphCapture() {
  if (!FLAGS_new_executor_use_cuda_graph) {
    return;
  }
#ifdef PADDLE_WITH_CUDA
  PADDLE_ENFORCE_EQ(
      platform::IsCUDAGraphCapturing(),
      false,
      platform::errors::PermissionDenied("CUDA Graph is not allowed to capture "
                                         "when running the first batch."));
  PADDLE_ENFORCE_EQ(platform::is_gpu_place(place_),
                    true,
                    platform::errors::InvalidArgument(
                        "CUDA Graph is only supported on NVIDIA GPU device."));
  // When FLAGS_sync_nccl_allreduce is on, `cudaStreamSynchronize(nccl_stream)`
  // is called after allreduce, which may cause an error in a CUDA graph.
  // This restriction is consistent with PE.
  PADDLE_ENFORCE_EQ(FLAGS_sync_nccl_allreduce,
                    false,
                    platform::errors::InvalidArgument(
                        "FLAGS_sync_nccl_allreduce must be False to support "
                        "CUDA Graph capturing."));
  // Outputs of coalesce_tensor must survive garbage collection: collecting
  // the fused output var causes an accuracy problem (the specific reasons
  // still need to be analyzed).
  for (auto& desc : block_.AllOps()) {
    if (desc->Type() != kCoalesceTensor) {
      continue;
    }
    for (auto& fused_out_name : desc->OutputArgumentNames()) {
      execution_config_.skip_gc_vars.insert(fused_out_name);
      VLOG(4) << "Insert Var(" << fused_out_name << ") into skip_gc_vars.";
    }
  }
#else
  PADDLE_THROW(platform::errors::Unimplemented(
      "CUDA Graph is only supported on NVIDIA GPU device."));
#endif
}
// Pre-run validation applied while a CUDA Graph capture is in progress.
//
// On CUDA builds, when a capture is active, it enforces that no data is being
// fed (`feed_names` is empty), that FLAGS_new_executor_use_cuda_graph is on,
// and that `place_` equals the capturing place. It is a no-op when nothing is
// being captured and on non-CUDA builds.
void InterpreterCore::CheckCUDAGraphBeforeRun(
    const std::vector<std::string>& feed_names) {
#ifdef PADDLE_WITH_CUDA
  if (!platform::IsCUDAGraphCapturing()) {
    return;
  }
  PADDLE_ENFORCE_EQ(
      feed_names.empty(),
      true,
      platform::errors::InvalidArgument(
          "Feeding data is not permitted when capturing CUDA Graph."));
  PADDLE_ENFORCE_EQ(
      FLAGS_new_executor_use_cuda_graph,
      true,
      platform::errors::InvalidArgument(
          "You must turn on FLAGS_new_executor_use_cuda_graph to True "
          "to enable CUDA Graph capturing."));
  PADDLE_ENFORCE_EQ(
      place_,
      platform::CUDAGraphCapturingPlace(),
      platform::errors::InvalidArgument("The place to capture CUDAGraph is "
                                        "not the same as the place to run."));
#endif
}
void InterpreterCore::BuildOperatorDependences() {
// analysis the dependences between ops, add next_instr_list to each instr,
// and set the dependecy_count_
......
......@@ -97,6 +97,10 @@ class InterpreterCore {
const std::vector<std::vector<size_t>>& input_var2op, size_t var_index);
void SetFeedVarsInplaceSkip(const std::vector<std::string>& feed_names);
// cuda graph
void CheckCUDAGraphBeforeRun(const std::vector<std::string>& feed_names);
void PrepareForCUDAGraphCapture();
// execution
void RunImpl();
void ExecuteInstructionList(const std::vector<Instruction>& vec_instr);
......
......@@ -18,6 +18,7 @@
#include "paddle/phi/backends/all_context.h"
DECLARE_bool(use_stream_safe_cuda_allocator);
DECLARE_bool(new_executor_use_cuda_graph);
namespace paddle {
namespace platform {
......@@ -43,7 +44,10 @@ void BeginCUDAGraphCapture(phi::GPUPlace place,
auto stream = dev_ctx->stream();
CUDAGraph::BeginCapture(place, stream, mode);
auto old_value = FLAGS_use_stream_safe_cuda_allocator;
// When using cuda graph in new executor, fast GC must be used.
// FLAGS_use_stream_safe_cuda_allocator should be true.
auto old_value = FLAGS_use_stream_safe_cuda_allocator &&
!FLAGS_new_executor_use_cuda_graph;
if (old_value) {
FLAGS_use_stream_safe_cuda_allocator = false;
}
......
......@@ -1010,6 +1010,18 @@ PADDLE_DEFINE_EXPORTED_bool(enable_cinn_auto_tune,
#endif
/*
* CUDA Graph related FLAG
* Name: FLAGS_new_executor_use_cuda_graph
* Since Version: 2.4
* Value Range: bool, default=false
* Example: FLAGS_new_executor_use_cuda_graph=true would allow
* new executor to use CUDA Graph.
* Note: When this flag is true, the new executor always uses fast GC and
* requires the stream-safe CUDA allocator to be in use (enforced in
* IsInterpretercoreFastGCEnabled).
*/
PADDLE_DEFINE_EXPORTED_bool(new_executor_use_cuda_graph,
false,
"Use CUDA Graph in new executor");
DEFINE_int32(record_pool_max_size,
2000000,
"SlotRecordDataset slot record pool max size");
......
......@@ -26,6 +26,7 @@ from .framework import convert_np_dtype_to_dtype_, _apply_pass
from . import core
from . import unique_name
from . import compiler
from . import set_flags
from .trainer_factory import TrainerFactory
from .trainer_factory import FetchHandlerMonitor
import copy
......@@ -510,6 +511,16 @@ def _is_dy2st_enable_standalone_executor():
]
def _is_cuda_graph_enable_standalone_executor():
    """Return True when the standalone executor is enabled for CUDA Graph.

    The switch is read from `framework._cuda_graph_enable_standalone_executor_`
    and counts as enabled when it equals 1, '1', True, 'True' or 'true'.
    """
    enabled_values = [1, '1', True, 'True', 'true']
    switch_value = framework._cuda_graph_enable_standalone_executor_
    return switch_value in enabled_values
def _prepare_fleet_executor():
from ..distributed.fleet.proto import fleet_executor_desc_pb2
......@@ -844,7 +855,19 @@ class _ExecutorCache:
)
build_strategy = compiled_program._build_strategy
# print(f"Program before convert:\n {inner_program}", flush=True)
use_cuda_graph = False
# When using cuda graph, the cuda graph preparation logic in PE is not
# executed, but it is processed in the constructor of new executor.
if (
build_strategy is not None
and build_strategy.allow_cuda_graph_capture
):
use_cuda_graph = True
build_strategy.allow_cuda_graph_capture = False
set_flags({"FLAGS_new_executor_use_cuda_graph": True})
compiled_program._compile(scope, place)
if use_cuda_graph:
build_strategy.allow_cuda_graph_capture = True
ir_graph = framework.IrGraph(compiled_program._graph)
converted_program = ir_graph.to_program()
......@@ -1746,24 +1769,25 @@ class Executor:
)
return False
# Unsupported case 4: CUDA Graph
# Unsupported case 4: async mode
if (
compiled_program._build_strategy is not None
and compiled_program._build_strategy.allow_cuda_graph_capture
and compiled_program._build_strategy.async_mode
):
warnings.warn(
"Standalone executor is not used for CUDA Graph",
"Standalone executor is not used for async mode",
UserWarning,
)
return False
# Unsupported case 5: async mode
# Unsupported case 5: CUDA Graph
if (
compiled_program._build_strategy is not None
and compiled_program._build_strategy.async_mode
and compiled_program._build_strategy.allow_cuda_graph_capture
and not _is_cuda_graph_enable_standalone_executor()
):
warnings.warn(
"Standalone executor is not used for async mode",
"Standalone executor is not used for CUDA Graph when FLAGS_CUDA_GRAPH_USE_STANDALONE_EXECUTOR=0",
UserWarning,
)
return False
......@@ -1811,8 +1835,13 @@ class Executor:
tensor = core.get_variable_tensor(scope, lr_sheduler._var_name)
# NOTE(dev): `tensor.set(data, self.place)` always call TensorCopySync that is a blocking behavior. So we use `_copy_from` to replace it.
cpu_tensor = _as_lodtensor(data, core.CPUPlace())
# for ipu, tensor is allocated on cpu
if core.is_compiled_with_ipu():
if core.is_cuda_graph_capturing():
warnings.warn(
"Caution!!! When capturing CUDA Graph, the learning rate scheduler would not "
"take any effect! Please set the learning rate manually before each batch!"
)
elif core.is_compiled_with_ipu():
# for ipu, tensor is allocated on cpu
tensor._copy_from(cpu_tensor, tensor._place())
else:
tensor._copy_from(cpu_tensor, self.place)
......
......@@ -86,6 +86,9 @@ _enable_standalone_executor_ = os.environ.get(
_dy2st_enable_standalone_executor_ = os.environ.get(
'FLAGS_DY2ST_USE_STANDALONE_EXECUTOR', 1
)
# Opt-in switch that lets CUDA Graph programs run on the standalone executor.
# It is read from the environment when this assignment executes and defaults
# to 0 (disabled). `_is_cuda_graph_enable_standalone_executor()` treats the
# values 1, '1', True, 'True' and 'true' as enabled.
_cuda_graph_enable_standalone_executor_ = os.environ.get(
'FLAGS_CUDA_GRAPH_USE_STANDALONE_EXECUTOR', 0
)
# Some explanation of our execution system 2022.03
# For now we have 3 kinds of execution system, since we refactored dygraph mode to
......
......@@ -1259,3 +1259,7 @@ set_tests_properties(test_parallel_executor_dry_run
PROPERTIES ENVIRONMENT "FLAGS_USE_STANDALONE_EXECUTOR=0")
set_tests_properties(test_parallel_executor_drop_scope
PROPERTIES ENVIRONMENT "FLAGS_USE_STANDALONE_EXECUTOR=0")
# Run the static-mode CUDA Graph test with the standalone executor enabled for
# CUDA Graph (the switch is off by default).
set_tests_properties(
test_cuda_graph_static_mode
PROPERTIES ENVIRONMENT "FLAGS_CUDA_GRAPH_USE_STANDALONE_EXECUTOR=1")
......@@ -18,18 +18,16 @@ import shutil
import unittest
import numpy as np
from simple_nets import simple_fc_net_with_inputs
import paddle
from paddle.device.cuda.graphs import CUDAGraph
from paddle.fluid.dygraph.base import switch_to_static_graph
def can_use_cuda_graph():
    """Return True only for a Paddle build compiled with CUDA and not ROCm."""
    if not paddle.is_compiled_with_cuda():
        return False
    return not paddle.is_compiled_with_rocm()
class TestCUDAGraph(unittest.TestCase):
class TestCUDAGraphInDygraphMode(unittest.TestCase):
def setUp(self):
if can_use_cuda_graph():
paddle.set_flags(
......@@ -46,94 +44,6 @@ class TestCUDAGraph(unittest.TestCase):
np.random.randint(low=0, high=10, size=shape).astype("float32")
)
# Runs the static-graph training helper twice with the same seed, once with
# CUDA Graph capture/replay and once without, and asserts the two returned
# losses are equal. Returns early without asserting when CUDA Graph cannot be
# used.
# NOTE(review): the hunk header above (-46,94 +44,6) suggests this method is on
# the removed side of the diff; an identical test appears in the new
# static-mode test file further down — confirm before relying on this copy.
@switch_to_static_graph
def test_cuda_graph_static_graph(self):
if not can_use_cuda_graph():
return
seed = 100
loss_cuda_graph = self.cuda_graph_static_graph_main(
seed, use_cuda_graph=True
)
loss_no_cuda_graph = self.cuda_graph_static_graph_main(
seed, use_cuda_graph=False
)
self.assertEqual(loss_cuda_graph, loss_no_cuda_graph)
# Builds a small FC network with an SGD optimizer driven by a PiecewiseDecay
# learning-rate schedule, runs 20 batches on CUDAPlace(0), and returns the loss
# tensor as a numpy array.
#   seed:           seed passed to both paddle.seed and np.random.seed.
#   use_cuda_graph: when true, batch 1 is captured into a CUDAGraph and that
#                   batch plus all later ones are executed via replay().
# NOTE(review): like the test above, this appears to be on the removed side of
# the diff (hunk header -46,94 +44,6) — confirm.
def cuda_graph_static_graph_main(self, seed, use_cuda_graph):
batch_size = 1
class_num = 10
image_shape = [batch_size, 784]
label_shape = [batch_size, 1]
paddle.seed(seed)
np.random.seed(seed)
startup = paddle.static.Program()
main = paddle.static.Program()
with paddle.static.program_guard(main, startup):
image = paddle.static.data(
name="image", shape=image_shape, dtype='float32'
)
label = paddle.static.data(
name="label", shape=label_shape, dtype='int64'
)
# Inputs and loss are marked persistable; they are read and written directly
# through scope tensors below instead of going through feed/fetch.
image.persistable = True
label.persistable = True
loss = simple_fc_net_with_inputs(image, label, class_num)
loss.persistable = True
lr = paddle.optimizer.lr.PiecewiseDecay(
boundaries=[2, 3, 4], values=[0.01, 0.02, 0.03, 0.04]
)
optimizer = paddle.optimizer.SGD(learning_rate=lr)
optimizer.minimize(loss)
place = paddle.CUDAPlace(0)
exe = paddle.static.Executor(place)
scope = paddle.static.Scope()
with paddle.static.scope_guard(scope):
exe.run(startup)
build_strategy = paddle.static.BuildStrategy()
build_strategy.allow_cuda_graph_capture = True
build_strategy.fix_op_run_order = True
build_strategy.fuse_all_optimizer_ops = True
compiled_program = paddle.static.CompiledProgram(
main
).with_data_parallel(
loss_name=loss.name, build_strategy=build_strategy, places=place
)
image_t = scope.var(image.name).get_tensor()
label_t = scope.var(label.name).get_tensor()
loss_t = scope.var(loss.name).get_tensor()
lr_var = main.global_block().var(lr._var_name)
self.assertTrue(lr_var.persistable)
lr_t = scope.var(lr_var.name).get_tensor()
cuda_graph = None
for batch_id in range(20):
# Write fresh random inputs straight into the scope tensors.
image_t.set(
np.random.rand(*image_shape).astype('float32'), place
)
label_t.set(
np.random.randint(
low=0, high=class_num, size=label_shape, dtype='int64'
),
place,
)
# Batch 0 runs normally; capture starts on batch 1.
if batch_id == 1 and use_cuda_graph:
cuda_graph = CUDAGraph(place, mode="global")
cuda_graph.capture_begin()
exe.run(compiled_program)
cuda_graph.capture_end()
if cuda_graph:
# The learning rate is written into the scope tensor by hand before each
# replay.
lr_t.set(np.array([lr()], dtype='float32'), place)
cuda_graph.replay()
else:
exe.run(compiled_program)
lr.step()
if cuda_graph:
cuda_graph.reset()
return np.array(loss_t)
def test_cuda_graph_dynamic_graph(self):
if not can_use_cuda_graph():
return
......
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
from simple_nets import simple_fc_net_with_inputs
import paddle
from paddle.device.cuda.graphs import CUDAGraph
from paddle.fluid.dygraph.base import switch_to_static_graph
def can_use_cuda_graph():
    """Return True only for a Paddle build compiled with CUDA and not ROCm."""
    if not paddle.is_compiled_with_cuda():
        return False
    return not paddle.is_compiled_with_rocm()
class TestCUDAGraphInStaticMode(unittest.TestCase):
    """Checks that static-graph training yields the same loss with and without
    CUDA Graph capture/replay."""

    def setUp(self):
        if can_use_cuda_graph():
            # The behavior of `FLAGS_use_stream_safe_cuda_allocator` in static
            # mode is inconsistent with that in dygraph mode.
            # In static mode, FLAGS_use_stream_safe_cuda_allocator must be True.
            # In dygraph mode, FLAGS_use_stream_safe_cuda_allocator must be False.
            # These two types of unittests need to be written separately, because
            # the allocator may only be initialized once, and the flag
            # `FLAGS_use_stream_safe_cuda_allocator` only takes effect during
            # initialization.
            paddle.set_flags(
                {
                    'FLAGS_allocator_strategy': 'auto_growth',
                    'FLAGS_sync_nccl_allreduce': False,
                    'FLAGS_cudnn_deterministic': True,
                    'FLAGS_use_stream_safe_cuda_allocator': True,
                }
            )

    @switch_to_static_graph
    def test_cuda_graph_static_graph(self):
        """The loss after 20 batches must match between the CUDA Graph run and
        the plain run started from the same seed."""
        if not can_use_cuda_graph():
            # Report the test as skipped instead of letting it pass vacuously.
            self.skipTest(
                "CUDA Graph requires a CUDA (non-ROCm) build of Paddle."
            )

        seed = 100
        loss_cuda_graph = self.cuda_graph_static_graph_main(
            seed, use_cuda_graph=True
        )
        loss_no_cuda_graph = self.cuda_graph_static_graph_main(
            seed, use_cuda_graph=False
        )
        # Both values are numpy arrays. `assertEqual` would raise ValueError
        # for anything but single-element arrays, so compare with
        # `np.array_equal`, which works for any shape and still treats NaN as
        # unequal.
        self.assertTrue(
            np.array_equal(loss_cuda_graph, loss_no_cuda_graph),
            msg=(
                f"loss with CUDA Graph {loss_cuda_graph} != "
                f"loss without CUDA Graph {loss_no_cuda_graph}"
            ),
        )

    def cuda_graph_static_graph_main(self, seed, use_cuda_graph):
        """Train a small FC network for 20 batches and return the final loss.

        Args:
            seed: seed passed to both `paddle.seed` and `np.random.seed`.
            use_cuda_graph: when true, batch 1 is captured into a CUDAGraph and
                that batch plus all later ones are executed via `replay()`.

        Returns:
            The loss tensor converted to a numpy array.
        """
        batch_size = 1
        class_num = 10
        image_shape = [batch_size, 784]
        label_shape = [batch_size, 1]

        paddle.seed(seed)
        np.random.seed(seed)
        startup = paddle.static.Program()
        main = paddle.static.Program()
        with paddle.static.program_guard(main, startup):
            image = paddle.static.data(
                name="image", shape=image_shape, dtype='float32'
            )
            label = paddle.static.data(
                name="label", shape=label_shape, dtype='int64'
            )
            # Inputs and loss are persistable: they are read and written
            # directly through scope tensors instead of feed/fetch.
            image.persistable = True
            label.persistable = True
            loss = simple_fc_net_with_inputs(image, label, class_num)
            loss.persistable = True
            lr = paddle.optimizer.lr.PiecewiseDecay(
                boundaries=[2, 3, 4], values=[0.01, 0.02, 0.03, 0.04]
            )
            optimizer = paddle.optimizer.SGD(learning_rate=lr)
            optimizer.minimize(loss)

        place = paddle.CUDAPlace(0)
        exe = paddle.static.Executor(place)
        scope = paddle.static.Scope()
        with paddle.static.scope_guard(scope):
            exe.run(startup)
            build_strategy = paddle.static.BuildStrategy()
            build_strategy.allow_cuda_graph_capture = True
            build_strategy.fix_op_run_order = True
            build_strategy.fuse_all_optimizer_ops = True
            compiled_program = paddle.static.CompiledProgram(
                main
            ).with_data_parallel(
                loss_name=loss.name, build_strategy=build_strategy, places=place
            )
            image_t = scope.var(image.name).get_tensor()
            label_t = scope.var(label.name).get_tensor()
            loss_t = scope.var(loss.name).get_tensor()
            lr_var = main.global_block().var(lr._var_name)
            self.assertTrue(lr_var.persistable)
            lr_t = scope.var(lr_var.name).get_tensor()
            cuda_graph = None
            for batch_id in range(20):
                # Write fresh random inputs straight into the scope tensors.
                image_t.set(
                    np.random.rand(*image_shape).astype('float32'), place
                )
                label_t.set(
                    np.random.randint(
                        low=0, high=class_num, size=label_shape, dtype='int64'
                    ),
                    place,
                )

                # Batch 0 runs normally; capture starts on batch 1.
                if batch_id == 1 and use_cuda_graph:
                    cuda_graph = CUDAGraph(place, mode="global")
                    cuda_graph.capture_begin()
                    exe.run(compiled_program)
                    cuda_graph.capture_end()

                if cuda_graph:
                    # The learning rate is written into the scope tensor by
                    # hand before each replay.
                    lr_t.set(np.array([lr()], dtype='float32'), place)
                    cuda_graph.replay()
                else:
                    exe.run(compiled_program)
                lr.step()
            if cuda_graph:
                cuda_graph.reset()
        return np.array(loss_t)
# Run the test cases in this file when it is executed as a script.
if __name__ == "__main__":
    unittest.main()
......@@ -623,6 +623,7 @@ HIGH_PARALLEL_JOB_NEW = [
'test_dataset_consistency_inspection',
'test_cuda_empty_cache',
'test_cuda_graph',
'test_cuda_graph_static_mode',
'test_disable_signal_handler',
'test_eig_op',
'test_eigh_op',
......@@ -2509,6 +2510,7 @@ TETRAD_PARALLEL_JOB = [
'test_dlpack',
'test_complex_variable',
'test_cuda_graph',
'test_cuda_graph_static_mode',
'test_custom_grad_input',
'test_accuracy_op',
'test_pool1d_api',
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册