From 9167fda39661d3c04151c004e2c7a4e6f97db5d1 Mon Sep 17 00:00:00 2001
From: pangyoki
Date: Mon, 20 Feb 2023 11:04:28 +0800
Subject: [PATCH] fix cuda graph error when new executor change feed fetch
 (#50306)

* change error

* fix

---
 .../framework/new_executor/interpretercore.cc | 20 +++--
 python/paddle/fluid/compiler.py               |  4 +
 .../fluid/tests/unittests/CMakeLists.txt      |  3 +
 .../unittests/test_cuda_graph_static_mode.py  | 79 ++++++++++++------
 .../test_cuda_graph_static_mode_error.py      | 83 +++++++++++++++++++
 5 files changed, 157 insertions(+), 32 deletions(-)
 create mode 100644 python/paddle/fluid/tests/unittests/test_cuda_graph_static_mode_error.py

diff --git a/paddle/fluid/framework/new_executor/interpretercore.cc b/paddle/fluid/framework/new_executor/interpretercore.cc
index 8845cbd571..8514296559 100644
--- a/paddle/fluid/framework/new_executor/interpretercore.cc
+++ b/paddle/fluid/framework/new_executor/interpretercore.cc
@@ -531,11 +531,6 @@ void InterpreterCore::BuildInplace() {
 void InterpreterCore::PrepareForCUDAGraphCapture() {
   if (!FLAGS_new_executor_use_cuda_graph) return;
 #ifdef PADDLE_WITH_CUDA
-  PADDLE_ENFORCE_EQ(
-      platform::IsCUDAGraphCapturing(),
-      false,
-      platform::errors::PermissionDenied("CUDA Graph is not allowed to capture "
-                                         "when running the first batch."));
   PADDLE_ENFORCE_EQ(platform::is_gpu_place(place_),
                     true,
                     platform::errors::InvalidArgument(
@@ -548,14 +543,23 @@ void InterpreterCore::PrepareForCUDAGraphCapture() {
                         "FLAGS_sync_nccl_allreduce must be False to support "
                         "CUDA Graph capturing."));
 
-  // All output vars of coalesce_tensor op should not be gc.
+  // All output vars of coalesce_tensor op should be persistable.
   // If fused output var of coalesce_tensor is gc, it will cause accuracy
   // problem. The specific reasons need to be analyzed.
   for (auto& op_desc : block_.AllOps()) {
     if (op_desc->Type() == kCoalesceTensor) {
       for (auto& out_var_name : op_desc->OutputArgumentNames()) {
-        execution_config_.skip_gc_vars.insert(out_var_name);
-        VLOG(4) << "Insert Var(" << out_var_name << ") into skip_gc_vars.";
+        // The fused var needs to be set to persistable, not just added to
+        // skip_gc_vars.
+        // In the case where the feed fetch var is changed, StandaloneExecutor
+        // will be newly constructed. If the fused var is not persistable,
+        // these vars will be recreated and initialized, resulting in
+        // precision problems.
+        auto* out_var = op_desc->Block()->FindVarRecursive(out_var_name);
+        if (out_var) {
+          out_var->SetPersistable(true);
+          VLOG(4) << "Mark Var(" << out_var_name << ") as Persistable.";
+        }
       }
     }
   }
diff --git a/python/paddle/fluid/compiler.py b/python/paddle/fluid/compiler.py
index 609bfa3d93..080c40b3ba 100644
--- a/python/paddle/fluid/compiler.py
+++ b/python/paddle/fluid/compiler.py
@@ -484,6 +484,10 @@ class CompiledProgram:
         self._persistable_vars = list(set(self._persistable_vars))
         self._persistable_vars.sort()
 
+        if core.is_cuda_graph_capturing():
+            raise RuntimeError(
+                "CUDA Graph is not allowed to capture when running the first batch."
+            )
         return core.ParallelExecutor(
             places,
             self._persistable_vars,
diff --git a/python/paddle/fluid/tests/unittests/CMakeLists.txt b/python/paddle/fluid/tests/unittests/CMakeLists.txt
index 84b4f2ee5c..7922505dc5 100755
--- a/python/paddle/fluid/tests/unittests/CMakeLists.txt
+++ b/python/paddle/fluid/tests/unittests/CMakeLists.txt
@@ -1230,3 +1230,6 @@ set_tests_properties(test_parallel_executor_drop_scope
 set_tests_properties(
   test_cuda_graph_static_mode PROPERTIES ENVIRONMENT
   "FLAGS_CUDA_GRAPH_USE_STANDALONE_EXECUTOR=1")
+set_tests_properties(
+  test_cuda_graph_static_mode_error
+  PROPERTIES ENVIRONMENT "FLAGS_CUDA_GRAPH_USE_STANDALONE_EXECUTOR=1")
diff --git a/python/paddle/fluid/tests/unittests/test_cuda_graph_static_mode.py b/python/paddle/fluid/tests/unittests/test_cuda_graph_static_mode.py
index e159334c87..3dc56cc703 100644
--- a/python/paddle/fluid/tests/unittests/test_cuda_graph_static_mode.py
+++ b/python/paddle/fluid/tests/unittests/test_cuda_graph_static_mode.py
@@ -26,8 +26,31 @@ def can_use_cuda_graph():
     return paddle.is_compiled_with_cuda() and not paddle.is_compiled_with_rocm()
 
 
+def build_program(main, startup, batch_size, class_num):
+    image_shape = [batch_size, 784]
+    label_shape = [batch_size, 1]
+    with paddle.static.program_guard(main, startup):
+        image = paddle.static.data(
+            name="image", shape=image_shape, dtype='float32'
+        )
+        label = paddle.static.data(
+            name="label", shape=label_shape, dtype='int64'
+        )
+        image.persistable = True
+        label.persistable = True
+        loss = simple_fc_net_with_inputs(image, label, class_num)
+        loss.persistable = True
+        lr = paddle.optimizer.lr.PiecewiseDecay(
+            boundaries=[2, 3, 4], values=[0.01, 0.02, 0.03, 0.04]
+        )
+        optimizer = paddle.optimizer.SGD(learning_rate=lr)
+        optimizer.minimize(loss)
+    return image, label, loss, lr
+
+
 class TestCUDAGraphInStaticMode(unittest.TestCase):
     def setUp(self):
+        self.init_data()
         if can_use_cuda_graph():
             # The behavior of `FLAGS_use_stream_safe_cuda_allocator` in static
             # mode is inconsistent with that in dygraph mode.
@@ -46,6 +69,9 @@ class TestCUDAGraphInStaticMode(unittest.TestCase):
                 }
             )
 
+    def init_data(self):
+        self.use_feed_data = False
+
     @switch_to_static_graph
     def test_cuda_graph_static_graph(self):
         if not can_use_cuda_graph():
@@ -70,22 +96,11 @@ class TestCUDAGraphInStaticMode(unittest.TestCase):
         np.random.seed(seed)
         startup = paddle.static.Program()
         main = paddle.static.Program()
-        with paddle.static.program_guard(main, startup):
-            image = paddle.static.data(
-                name="image", shape=image_shape, dtype='float32'
-            )
-            label = paddle.static.data(
-                name="label", shape=label_shape, dtype='int64'
-            )
-            image.persistable = True
-            label.persistable = True
-            loss = simple_fc_net_with_inputs(image, label, class_num)
-            loss.persistable = True
-            lr = paddle.optimizer.lr.PiecewiseDecay(
-                boundaries=[2, 3, 4], values=[0.01, 0.02, 0.03, 0.04]
-            )
-            optimizer = paddle.optimizer.SGD(learning_rate=lr)
-            optimizer.minimize(loss)
+
+        image, label, loss, lr = build_program(
+            main, startup, batch_size, class_num
+        )
+
         place = paddle.CUDAPlace(0)
         exe = paddle.static.Executor(place)
         scope = paddle.static.Scope()
@@ -108,15 +123,16 @@ class TestCUDAGraphInStaticMode(unittest.TestCase):
             lr_t = scope.var(lr_var.name).get_tensor()
             cuda_graph = None
             for batch_id in range(20):
-                image_t.set(
-                    np.random.rand(*image_shape).astype('float32'), place
+                use_feed_data = (
+                    True if batch_id == 0 and self.use_feed_data else False
                 )
-                label_t.set(
-                    np.random.randint(
-                        low=0, high=class_num, size=label_shape, dtype='int64'
-                    ),
-                    place,
+                image_np = np.random.rand(*image_shape).astype('float32')
+                label_np = np.random.randint(
+                    low=0, high=class_num, size=label_shape, dtype='int64'
                 )
+                if not use_feed_data:
+                    image_t.set(image_np, place)
+                    label_t.set(label_np, place)
                 if batch_id == 1 and use_cuda_graph:
                     cuda_graph = CUDAGraph(place, mode="global")
@@ -128,12 +144,27 @@ class TestCUDAGraphInStaticMode(unittest.TestCase):
                     lr_t.set(np.array([lr()], dtype='float32'), place)
                     cuda_graph.replay()
                 else:
-                    exe.run(compiled_program)
+                    if use_feed_data:
+                        exe.run(
+                            compiled_program,
+                            feed={'image': image_np, 'label': label_np},
+                        )
+                    else:
+                        exe.run(compiled_program)
                 lr.step()
             if cuda_graph:
                 cuda_graph.reset()
         return np.array(loss_t)
 
 
+class TestCUDAGraphWhenFeedDataChanges(TestCUDAGraphInStaticMode):
+    def init_data(self):
+        # When the feed/fetch vars of the new executor change, a new
+        # StandaloneExecutor will be created, and the behavior of
+        # capturing the CUDA graph changes as well.
+        # Add a test for this case.
+        self.use_feed_data = True
+
+
 if __name__ == "__main__":
     unittest.main()
diff --git a/python/paddle/fluid/tests/unittests/test_cuda_graph_static_mode_error.py b/python/paddle/fluid/tests/unittests/test_cuda_graph_static_mode_error.py
new file mode 100644
index 0000000000..9cf0945992
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/test_cuda_graph_static_mode_error.py
@@ -0,0 +1,83 @@
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+import numpy as np
+from test_cuda_graph_static_mode import build_program, can_use_cuda_graph
+
+import paddle
+from paddle.device.cuda.graphs import CUDAGraph
+from paddle.fluid.dygraph.base import switch_to_static_graph
+
+
+class TestCUDAGraphInFirstBatch(unittest.TestCase):
+    def setUp(self):
+        if can_use_cuda_graph():
+            paddle.set_flags(
+                {
+                    'FLAGS_allocator_strategy': 'auto_growth',
+                    'FLAGS_sync_nccl_allreduce': False,
+                    'FLAGS_cudnn_deterministic': True,
+                    'FLAGS_use_stream_safe_cuda_allocator': True,
+                }
+            )
+
+    @switch_to_static_graph
+    def test_cuda_graph_in_first_batch(self):
+        if not can_use_cuda_graph():
+            return
+
+        startup = paddle.static.Program()
+        main = paddle.static.Program()
+
+        image, label, loss, lr = build_program(main, startup, 1, 10)
+
+        place = paddle.CUDAPlace(0)
+        exe = paddle.static.Executor(place)
+        scope = paddle.static.Scope()
+        with paddle.static.scope_guard(scope):
+            exe.run(startup)
+            build_strategy = paddle.static.BuildStrategy()
+            build_strategy.allow_cuda_graph_capture = True
+            compiled_program = paddle.static.CompiledProgram(
+                main
+            ).with_data_parallel(
+                loss_name=loss.name, build_strategy=build_strategy, places=place
+            )
+
+            cuda_graph = None
+
+            image_t = scope.var(image.name).get_tensor()
+            label_t = scope.var(label.name).get_tensor()
+            image_np = np.random.rand(1, 784).astype('float32')
+            label_np = np.random.randint(
+                low=0, high=10, size=[1, 1], dtype='int64'
+            )
+            image_t.set(image_np, place)
+            label_t.set(label_np, place)
+
+            # CUDA Graph is not allowed to capture when running the first batch
+            with self.assertRaises(RuntimeError):
+                cuda_graph = CUDAGraph(place, mode="global")
+                cuda_graph.capture_begin()
+                exe.run(compiled_program)
+                cuda_graph.capture_end()
+
+            if cuda_graph:
+                cuda_graph.reset()
+
+
+if __name__ == "__main__":
+    unittest.main()
-- 
GitLab
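Note (not part of the patch): a minimal sketch, condensed from the tests above, of the warm-up-then-capture ordering that the new RuntimeError in compiler.py enforces. The first batch must run eagerly so the underlying executor is constructed before capture; capture is only legal from the second batch on. This assumes a CUDA build of Paddle, that test_cuda_graph_static_mode.py (and its build_program helper) is importable, and that FLAGS_CUDA_GRAPH_USE_STANDALONE_EXECUTOR=1 is set in the environment as in the CMake rules above.

import numpy as np

import paddle
from paddle.device.cuda.graphs import CUDAGraph
from test_cuda_graph_static_mode import build_program

paddle.enable_static()
# Mirror the flags used by the tests' setUp().
paddle.set_flags(
    {
        'FLAGS_allocator_strategy': 'auto_growth',
        'FLAGS_use_stream_safe_cuda_allocator': True,
    }
)

startup = paddle.static.Program()
main = paddle.static.Program()
image, label, loss, lr = build_program(main, startup, 1, 10)

place = paddle.CUDAPlace(0)
exe = paddle.static.Executor(place)
scope = paddle.static.Scope()
with paddle.static.scope_guard(scope):
    exe.run(startup)
    build_strategy = paddle.static.BuildStrategy()
    build_strategy.allow_cuda_graph_capture = True
    compiled_program = paddle.static.CompiledProgram(main).with_data_parallel(
        loss_name=loss.name, build_strategy=build_strategy, places=place
    )

    # Feed inputs through the persistable vars, as the tests do.
    image_t = scope.var(image.name).get_tensor()
    label_t = scope.var(label.name).get_tensor()
    image_t.set(np.random.rand(1, 784).astype('float32'), place)
    label_t.set(np.random.randint(0, 10, size=[1, 1], dtype='int64'), place)

    # Batch 0: run eagerly. This constructs the ParallelExecutor /
    # StandaloneExecutor; starting capture here would raise the
    # RuntimeError added in compiler.py.
    exe.run(compiled_program)

    # Batch 1: the executor already exists, so capture is allowed.
    cuda_graph = CUDAGraph(place, mode="global")
    cuda_graph.capture_begin()
    exe.run(compiled_program)
    cuda_graph.capture_end()

    cuda_graph.replay()  # re-run the captured batch
    cuda_graph.reset()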