Unverified commit 9167fda3, authored by pangyoki, committed by GitHub

fix cuda graph error when new executor change feed fetch (#50306)

When the feed/fetch targets passed to the new (standalone) executor change, a new StandaloneExecutor is constructed. This commit marks the fused coalesce_tensor output vars persistable so they survive that reconstruction, and moves the "no capture on the first batch" check from InterpreterCore to CompiledProgram.

* change error

* fix
Parent e7a7a7a6
@@ -531,11 +531,6 @@ void InterpreterCore::BuildInplace() {
 void InterpreterCore::PrepareForCUDAGraphCapture() {
   if (!FLAGS_new_executor_use_cuda_graph) return;
 #ifdef PADDLE_WITH_CUDA
-  PADDLE_ENFORCE_EQ(
-      platform::IsCUDAGraphCapturing(),
-      false,
-      platform::errors::PermissionDenied("CUDA Graph is not allowed to capture "
-                                         "when running the first batch."));
   PADDLE_ENFORCE_EQ(platform::is_gpu_place(place_),
                     true,
                     platform::errors::InvalidArgument(
@@ -548,14 +543,23 @@ void InterpreterCore::PrepareForCUDAGraphCapture() {
           "FLAGS_sync_nccl_allreduce must be False to support "
           "CUDA Graph capturing."));

-  // All output vars of coalesce_tensor op should not be gc.
+  // All output vars of coalesce_tensor op should be persistable.
   // If fused output var of coalesce_tensor is gc, it will cause accuracy
   // problem. The specific reasons need to be analyzed.
   for (auto& op_desc : block_.AllOps()) {
     if (op_desc->Type() == kCoalesceTensor) {
       for (auto& out_var_name : op_desc->OutputArgumentNames()) {
-        execution_config_.skip_gc_vars.insert(out_var_name);
-        VLOG(4) << "Insert Var(" << out_var_name << ") into skip_gc_vars.";
+        // The fused var needs to be set to persistable, not just added to
+        // skip_gc_vars.
+        // In the case where the feed fetch var is changed, StandaloneExecutor
+        // will be newly constructed. If the fused var is not persistable,
+        // these vars will be recreated and initialized, resulting in
+        // precision problems.
+        auto* out_var = op_desc->Block()->FindVarRecursive(out_var_name);
+        if (out_var) {
+          out_var->SetPersistable(true);
+          VLOG(4) << "Mark Var(" << out_var_name << ") as Persistable.";
+        }
       }
     }
   }
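For orientation (an illustration, not part of the commit): a minimal sketch of the scenario the new comment describes. It reuses the build_program helper that this change adds to test_cuda_graph_static_mode.py, and the feed/fetch behavior is condensed from the tests below; whether coalesce_tensor actually fires depends on the build passes, so treat this purely as a sketch of the feed/fetch change.

    # Sketch only: run the same program once through feed and once with
    # tensors pre-set in the scope; the feed/fetch signature changes, which
    # is what forces a fresh StandaloneExecutor under the new executor.
    import numpy as np
    import paddle
    from test_cuda_graph_static_mode import build_program

    paddle.enable_static()
    main, startup = paddle.static.Program(), paddle.static.Program()
    image, label, loss, lr = build_program(main, startup, 1, 10)

    place = paddle.CUDAPlace(0)
    exe = paddle.static.Executor(place)
    scope = paddle.static.Scope()
    with paddle.static.scope_guard(scope):
        exe.run(startup)
        image_np = np.random.rand(1, 784).astype('float32')
        label_np = np.random.randint(low=0, high=10, size=[1, 1], dtype='int64')
        # Batch 0: data arrives through feed -> one feed/fetch signature.
        exe.run(main, feed={'image': image_np, 'label': label_np})
        # Batch 1: tensors are pre-set in the scope and run without feed.
        # The signature changes and a new StandaloneExecutor is built; the
        # fused coalesce_tensor outputs must stay persistable so they are
        # not recreated and re-initialized during that rebuild.
        scope.var(image.name).get_tensor().set(image_np, place)
        scope.var(label.name).get_tensor().set(label_np, place)
        exe.run(main)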
@@ -484,6 +484,10 @@ class CompiledProgram:
         self._persistable_vars = list(set(self._persistable_vars))
         self._persistable_vars.sort()

+        if core.is_cuda_graph_capturing():
+            raise RuntimeError(
+                "CUDA Graph is not allowed to capture when running the first batch."
+            )
         return core.ParallelExecutor(
             places,
             self._persistable_vars,
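The user-visible effect of this check, sketched as an illustration only: place, exe, and compiled_program are assumed to be set up exactly as in the new test_cuda_graph_static_mode_error.py at the end of this diff.

    # Beginning capture before the compiled program has run its first batch
    # now fails fast with the RuntimeError raised above.
    from paddle.device.cuda.graphs import CUDAGraph

    cuda_graph = CUDAGraph(place, mode="global")
    try:
        cuda_graph.capture_begin()
        exe.run(compiled_program)  # first batch while capturing
        cuda_graph.capture_end()
    except RuntimeError:
        # "CUDA Graph is not allowed to capture when running the first batch."
        pass
    finally:
        cuda_graph.reset()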
@@ -1230,3 +1230,6 @@ set_tests_properties(test_parallel_executor_drop_scope
 set_tests_properties(
   test_cuda_graph_static_mode
   PROPERTIES ENVIRONMENT "FLAGS_CUDA_GRAPH_USE_STANDALONE_EXECUTOR=1")
+set_tests_properties(
+  test_cuda_graph_static_mode_error
+  PROPERTIES ENVIRONMENT "FLAGS_CUDA_GRAPH_USE_STANDALONE_EXECUTOR=1")
@@ -26,8 +26,31 @@ def can_use_cuda_graph():
     return paddle.is_compiled_with_cuda() and not paddle.is_compiled_with_rocm()


+def build_program(main, startup, batch_size, class_num):
+    image_shape = [batch_size, 784]
+    label_shape = [batch_size, 1]
+    with paddle.static.program_guard(main, startup):
+        image = paddle.static.data(
+            name="image", shape=image_shape, dtype='float32'
+        )
+        label = paddle.static.data(
+            name="label", shape=label_shape, dtype='int64'
+        )
+        image.persistable = True
+        label.persistable = True
+        loss = simple_fc_net_with_inputs(image, label, class_num)
+        loss.persistable = True
+        lr = paddle.optimizer.lr.PiecewiseDecay(
+            boundaries=[2, 3, 4], values=[0.01, 0.02, 0.03, 0.04]
+        )
+        optimizer = paddle.optimizer.SGD(learning_rate=lr)
+        optimizer.minimize(loss)
+    return image, label, loss, lr
+
+
 class TestCUDAGraphInStaticMode(unittest.TestCase):
     def setUp(self):
+        self.init_data()
         if can_use_cuda_graph():
             # The behavior of `FLAGS_use_stream_safe_cuda_allocator` in static
             # mode is inconsistent with that in dygraph mode.
@@ -46,6 +69,9 @@ class TestCUDAGraphInStaticMode(unittest.TestCase):
                 }
             )

+    def init_data(self):
+        self.use_feed_data = False
+
     @switch_to_static_graph
     def test_cuda_graph_static_graph(self):
         if not can_use_cuda_graph():
@@ -70,22 +96,11 @@ class TestCUDAGraphInStaticMode(unittest.TestCase):
         np.random.seed(seed)
         startup = paddle.static.Program()
         main = paddle.static.Program()
-        with paddle.static.program_guard(main, startup):
-            image = paddle.static.data(
-                name="image", shape=image_shape, dtype='float32'
-            )
-            label = paddle.static.data(
-                name="label", shape=label_shape, dtype='int64'
-            )
-            image.persistable = True
-            label.persistable = True
-            loss = simple_fc_net_with_inputs(image, label, class_num)
-            loss.persistable = True
-            lr = paddle.optimizer.lr.PiecewiseDecay(
-                boundaries=[2, 3, 4], values=[0.01, 0.02, 0.03, 0.04]
-            )
-            optimizer = paddle.optimizer.SGD(learning_rate=lr)
-            optimizer.minimize(loss)
+        image, label, loss, lr = build_program(
+            main, startup, batch_size, class_num
+        )
+
         place = paddle.CUDAPlace(0)
         exe = paddle.static.Executor(place)
         scope = paddle.static.Scope()
@@ -108,15 +123,16 @@ class TestCUDAGraphInStaticMode(unittest.TestCase):
             lr_t = scope.var(lr_var.name).get_tensor()
             cuda_graph = None
             for batch_id in range(20):
-                image_t.set(
-                    np.random.rand(*image_shape).astype('float32'), place
+                use_feed_data = (
+                    True if batch_id == 0 and self.use_feed_data else False
                 )
-                label_t.set(
-                    np.random.randint(
-                        low=0, high=class_num, size=label_shape, dtype='int64'
-                    ),
-                    place,
+                image_np = np.random.rand(*image_shape).astype('float32')
+                label_np = np.random.randint(
+                    low=0, high=class_num, size=label_shape, dtype='int64'
                 )
+                if not use_feed_data:
+                    image_t.set(image_np, place)
+                    label_t.set(label_np, place)
                 if batch_id == 1 and use_cuda_graph:
                     cuda_graph = CUDAGraph(place, mode="global")
@@ -128,12 +144,27 @@ class TestCUDAGraphInStaticMode(unittest.TestCase):
                     lr_t.set(np.array([lr()], dtype='float32'), place)
                     cuda_graph.replay()
                 else:
-                    exe.run(compiled_program)
+                    if use_feed_data:
+                        exe.run(
+                            compiled_program,
+                            feed={'image': image_np, 'label': label_np},
+                        )
+                    else:
+                        exe.run(compiled_program)
                 lr.step()
             if cuda_graph:
                 cuda_graph.reset()
             return np.array(loss_t)


+class TestCUDAGraphWhenFeedDataChanges(TestCUDAGraphInStaticMode):
+    def init_data(self):
+        # When the feed/fetch vars of the new executor change, a new
+        # StandaloneExecutor is created and the behavior of CUDA Graph
+        # capture changes. Add a test for this case.
+        self.use_feed_data = True
+
+
 if __name__ == "__main__":
     unittest.main()
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import numpy as np
from test_cuda_graph_static_mode import build_program, can_use_cuda_graph

import paddle
from paddle.device.cuda.graphs import CUDAGraph
from paddle.fluid.dygraph.base import switch_to_static_graph


class TestCUDAGraphInFirstBatch(unittest.TestCase):
    def setUp(self):
        if can_use_cuda_graph():
            paddle.set_flags(
                {
                    'FLAGS_allocator_strategy': 'auto_growth',
                    'FLAGS_sync_nccl_allreduce': False,
                    'FLAGS_cudnn_deterministic': True,
                    'FLAGS_use_stream_safe_cuda_allocator': True,
                }
            )

    @switch_to_static_graph
    def test_cuda_graph_in_first_batch(self):
        if not can_use_cuda_graph():
            return

        startup = paddle.static.Program()
        main = paddle.static.Program()
        image, label, loss, lr = build_program(main, startup, 1, 10)

        place = paddle.CUDAPlace(0)
        exe = paddle.static.Executor(place)
        scope = paddle.static.Scope()
        with paddle.static.scope_guard(scope):
            exe.run(startup)
            build_strategy = paddle.static.BuildStrategy()
            build_strategy.allow_cuda_graph_capture = True
            compiled_program = paddle.static.CompiledProgram(
                main
            ).with_data_parallel(
                loss_name=loss.name, build_strategy=build_strategy, places=place
            )

            cuda_graph = None
            image_t = scope.var(image.name).get_tensor()
            label_t = scope.var(label.name).get_tensor()
            image_np = np.random.rand(1, 784).astype('float32')
            label_np = np.random.randint(
                low=0, high=10, size=[1, 1], dtype='int64'
            )
            image_t.set(image_np, place)
            label_t.set(label_np, place)

            # CUDA Graph is not allowed to capture when running the first batch
            with self.assertRaises(RuntimeError):
                cuda_graph = CUDAGraph(place, mode="global")
                cuda_graph.capture_begin()
                exe.run(compiled_program)
                cuda_graph.capture_end()

            if cuda_graph:
                cuda_graph.reset()


if __name__ == "__main__":
    unittest.main()