diff --git a/paddle/fluid/framework/new_executor/garbage_collector/garbage_collector.cc b/paddle/fluid/framework/new_executor/garbage_collector/garbage_collector.cc index e7e925a47797faf6aa2457ca78e62b8a6ee1bef2..73e6664f66f1e04a810e4ed58d13f5b7c05e528e 100644 --- a/paddle/fluid/framework/new_executor/garbage_collector/garbage_collector.cc +++ b/paddle/fluid/framework/new_executor/garbage_collector/garbage_collector.cc @@ -19,14 +19,27 @@ #include "paddle/fluid/framework/new_executor/garbage_collector/no_event_garbage_collector.h" DECLARE_bool(fast_eager_deletion_mode); +DECLARE_bool(new_executor_use_cuda_graph); namespace paddle { namespace framework { bool IsInterpretercoreFastGCEnabled() { - return memory::allocation::AllocatorFacade::Instance() - .IsStreamSafeCUDAAllocatorUsed() && - FLAGS_fast_eager_deletion_mode; + // When using CUDA Graph, fast GC must be used, because the + // `EventQuery` method in event GC cannot be used in + // CUDA Graph. + PADDLE_ENFORCE_EQ(memory::allocation::AllocatorFacade::Instance() + .IsStreamSafeCUDAAllocatorUsed() == false && + FLAGS_new_executor_use_cuda_graph, + false, + platform::errors::InvalidArgument( + "When FLAGS_new_executor_use_cuda_graph is true, " + "IsStreamSafeCUDAAllocatorUsed must be true, but " + "got false.")); + return (memory::allocation::AllocatorFacade::Instance() + .IsStreamSafeCUDAAllocatorUsed() && + FLAGS_fast_eager_deletion_mode) || + FLAGS_new_executor_use_cuda_graph; } InterpreterCoreGarbageCollector::InterpreterCoreGarbageCollector() { diff --git a/paddle/fluid/framework/new_executor/interpretercore.cc b/paddle/fluid/framework/new_executor/interpretercore.cc index 4f2a4f48b7f99749963eaa39160360b54a5620d0..63525330ea60debc6db363d96f6049153cd4550a 100644 --- a/paddle/fluid/framework/new_executor/interpretercore.cc +++ b/paddle/fluid/framework/new_executor/interpretercore.cc @@ -31,6 +31,7 @@ #ifdef PADDLE_WITH_MKLDNN #include "paddle/fluid/platform/mkldnn_helper.h" #endif +#include 
"paddle/fluid/platform/cuda_graph_with_memory_pool.h" #include "paddle/phi/backends/device_manager.h" PADDLE_DEFINE_EXPORTED_bool( @@ -50,6 +51,10 @@ PADDLE_DEFINE_EXPORTED_bool(control_flow_use_new_executor, DECLARE_bool(check_nan_inf); DECLARE_bool(benchmark); +DECLARE_bool(new_executor_use_cuda_graph); +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) +DECLARE_bool(sync_nccl_allreduce); +#endif constexpr const char* kExceptionCaught = "ExceptionCaught"; constexpr const char* kTaskCompletion = "TaskCompletion"; @@ -142,6 +147,8 @@ InterpreterCore::InterpreterCore(const platform::Place& place, } return lhs_prority > rhs_prority; }; + + PrepareForCUDAGraphCapture(); } InterpreterCore::~InterpreterCore() { @@ -161,6 +168,7 @@ interpreter::CostInfo InterpreterCore::DryRun( const std::vector& feed_names, const std::vector& feed_tensors) { SetDeviceId(place_); + CheckCUDAGraphBeforeRun(feed_names); Prepare(feed_names, feed_tensors, true); interpreter::CostInfo cost_info; @@ -221,6 +229,7 @@ paddle::framework::FetchList InterpreterCore::Run( const std::vector& feed_names, const std::vector& feed_tensors) { SetDeviceId(place_); + CheckCUDAGraphBeforeRun(feed_names); #ifdef PADDLE_WITH_MKLDNN platform::AttachPointerHashToMKLDNNKey(this, place_); @@ -240,7 +249,16 @@ paddle::framework::FetchList InterpreterCore::Run( // return Fetch Tensors auto* fetch_var = local_scope_->FindVar(interpreter::kFetchVarName); if (fetch_var) { - return std::move(*fetch_var->GetMutable()); + auto fetch_list = std::move(*fetch_var->GetMutable()); +#ifdef PADDLE_WITH_CUDA + if (platform::IsCUDAGraphCapturing()) { + PADDLE_ENFORCE_EQ(fetch_list.empty(), + true, + platform::errors::InvalidArgument( + "Cannot fetch data when using CUDA Graph.")); + } +#endif + return fetch_list; } else { return {}; } @@ -249,6 +267,7 @@ paddle::framework::FetchList InterpreterCore::Run( paddle::framework::FetchList InterpreterCore::Run( const std::vector& feed_names, bool need_fetch) { 
SetDeviceId(place_); + CheckCUDAGraphBeforeRun(feed_names); #ifdef PADDLE_WITH_MKLDNN platform::AttachPointerHashToMKLDNNKey(this, place_); @@ -290,7 +309,16 @@ paddle::framework::FetchList InterpreterCore::Run( HasLocalScope() ? local_scope_ : var_scope_.GetMutableScope(); auto* fetch_var = inner_scope->FindVar(interpreter::kFetchVarName); if (fetch_var && need_fetch) { - return std::move(*fetch_var->GetMutable()); + auto fetch_list = std::move(*fetch_var->GetMutable()); +#ifdef PADDLE_WITH_CUDA + if (platform::IsCUDAGraphCapturing()) { + PADDLE_ENFORCE_EQ(fetch_list.empty(), + true, + platform::errors::InvalidArgument( + "Cannot fetch data when using CUDA Graph.")); + } +#endif + return fetch_list; } else { return {}; } @@ -504,6 +532,67 @@ void InterpreterCore::BuildInplace() { } } +void InterpreterCore::PrepareForCUDAGraphCapture() { + if (!FLAGS_new_executor_use_cuda_graph) return; +#ifdef PADDLE_WITH_CUDA + PADDLE_ENFORCE_EQ( + platform::IsCUDAGraphCapturing(), + false, + platform::errors::PermissionDenied("CUDA Graph is not allowed to capture " + "when running the first batch.")); + PADDLE_ENFORCE_EQ(platform::is_gpu_place(place_), + true, + platform::errors::InvalidArgument( + "CUDA Graph is only supported on NVIDIA GPU device.")); + // If FLAGS_sync_nccl_allreduce is true, `cudaStreamSynchronize(nccl_stream)` is + // called after allreduce, which may cause an error in CUDA Graph. This behavior is consistent with PE. + PADDLE_ENFORCE_EQ(FLAGS_sync_nccl_allreduce, + false, + platform::errors::InvalidArgument( + "FLAGS_sync_nccl_allreduce must be False to support " + "CUDA Graph capturing.")); + + // Output vars of the coalesce_tensor op must not be garbage-collected. + // If the fused output var of coalesce_tensor is garbage-collected, it causes an + // accuracy problem. The specific reason still needs to be analyzed.
+ for (auto& op_desc : block_.AllOps()) { + if (op_desc->Type() == kCoalesceTensor) { + for (auto& out_var_name : op_desc->OutputArgumentNames()) { + execution_config_.skip_gc_vars.insert(out_var_name); + VLOG(4) << "Insert Var(" << out_var_name << ") into skip_gc_vars."; + } + } + } +#else + PADDLE_THROW(platform::errors::Unimplemented( + "CUDA Graph is only supported on NVIDIA GPU device.")); +#endif +} + +void InterpreterCore::CheckCUDAGraphBeforeRun( + const std::vector& feed_names) { +#ifdef PADDLE_WITH_CUDA + if (platform::IsCUDAGraphCapturing()) { + PADDLE_ENFORCE_EQ( + feed_names.empty(), + true, + platform::errors::InvalidArgument( + "Feeding data is not permitted when capturing CUDA Graph.")); + PADDLE_ENFORCE_EQ( + FLAGS_new_executor_use_cuda_graph, + true, + platform::errors::InvalidArgument( + "You must turn on FLAGS_new_executor_use_cuda_graph to True " + "to enable CUDA Graph capturing.")); + PADDLE_ENFORCE_EQ( + place_, + platform::CUDAGraphCapturingPlace(), + platform::errors::InvalidArgument("The place to capture CUDAGraph is " + "not the same as the place to run.")); + } +#endif +} + void InterpreterCore::BuildOperatorDependences() { // analysis the dependences between ops, add next_instr_list to each instr, // and set the dependecy_count_ diff --git a/paddle/fluid/framework/new_executor/interpretercore.h b/paddle/fluid/framework/new_executor/interpretercore.h index 74ff5c563652ea486d552c4d1ecf2cbb363fa04d..53625c87938305c6a22909d70352d0cb1095b1d0 100644 --- a/paddle/fluid/framework/new_executor/interpretercore.h +++ b/paddle/fluid/framework/new_executor/interpretercore.h @@ -97,6 +97,10 @@ class InterpreterCore { const std::vector>& input_var2op, size_t var_index); void SetFeedVarsInplaceSkip(const std::vector& feed_names); + // cuda graph + void CheckCUDAGraphBeforeRun(const std::vector& feed_names); + void PrepareForCUDAGraphCapture(); + // execution void RunImpl(); void ExecuteInstructionList(const std::vector& vec_instr); diff --git 
a/paddle/fluid/platform/cuda_graph_with_memory_pool.cc b/paddle/fluid/platform/cuda_graph_with_memory_pool.cc index 7a5acb762eb83bbc52254cc2427938a1c8f0ba39..2f14a23168533cfdf34072b30a26b186d039d2c1 100644 --- a/paddle/fluid/platform/cuda_graph_with_memory_pool.cc +++ b/paddle/fluid/platform/cuda_graph_with_memory_pool.cc @@ -18,6 +18,7 @@ #include "paddle/phi/backends/all_context.h" DECLARE_bool(use_stream_safe_cuda_allocator); +DECLARE_bool(new_executor_use_cuda_graph); namespace paddle { namespace platform { @@ -43,7 +44,10 @@ void BeginCUDAGraphCapture(phi::GPUPlace place, auto stream = dev_ctx->stream(); CUDAGraph::BeginCapture(place, stream, mode); - auto old_value = FLAGS_use_stream_safe_cuda_allocator; + // When using cuda graph in new executor, fast GC must be used. + // FLAGS_use_stream_safe_cuda_allocator should be true. + auto old_value = FLAGS_use_stream_safe_cuda_allocator && + !FLAGS_new_executor_use_cuda_graph; if (old_value) { FLAGS_use_stream_safe_cuda_allocator = false; } diff --git a/paddle/phi/core/flags.cc b/paddle/phi/core/flags.cc index 5b8dc47d6498177124bdc53be2d27f97546fe086..43da2ecb7bb6e1ff532020bd636c80276cfa128d 100644 --- a/paddle/phi/core/flags.cc +++ b/paddle/phi/core/flags.cc @@ -1010,6 +1010,18 @@ PADDLE_DEFINE_EXPORTED_bool(enable_cinn_auto_tune, #endif +/* + * CUDA Graph related FLAG + * Name: FLAGS_new_executor_use_cuda_graph + * Since Version: 2.4 + * Value Range: bool, default=false + * Example: FLAGS_new_executor_use_cuda_graph=true would allow + * new executor to use CUDA Graph. 
+ */ +PADDLE_DEFINE_EXPORTED_bool(new_executor_use_cuda_graph, + false, + "Use CUDA Graph in new executor"); + DEFINE_int32(record_pool_max_size, 2000000, "SlotRecordDataset slot record pool max size"); diff --git a/python/paddle/fluid/executor.py b/python/paddle/fluid/executor.py index d5db9a7f72c0de77c7f8ee6a3467bf6f721bcf01..e3376d8446586607947bf2143fe5c9fe32115dac 100755 --- a/python/paddle/fluid/executor.py +++ b/python/paddle/fluid/executor.py @@ -26,6 +26,7 @@ from .framework import convert_np_dtype_to_dtype_, _apply_pass from . import core from . import unique_name from . import compiler +from . import set_flags from .trainer_factory import TrainerFactory from .trainer_factory import FetchHandlerMonitor import copy @@ -510,6 +511,16 @@ def _is_dy2st_enable_standalone_executor(): ] +def _is_cuda_graph_enable_standalone_executor(): + return framework._cuda_graph_enable_standalone_executor_ in [ + 1, + '1', + True, + 'True', + 'true', + ] + + def _prepare_fleet_executor(): from ..distributed.fleet.proto import fleet_executor_desc_pb2 @@ -844,7 +855,19 @@ class _ExecutorCache: ) build_strategy = compiled_program._build_strategy # print(f"Program before convert:\n {inner_program}", flush=True) + use_cuda_graph = False + # When using cuda graph, the cuda graph preparation logic in PE is not + # executed, but it is processed in the constructor of new executor. 
+ if ( + build_strategy is not None + and build_strategy.allow_cuda_graph_capture + ): + use_cuda_graph = True + build_strategy.allow_cuda_graph_capture = False + set_flags({"FLAGS_new_executor_use_cuda_graph": True}) compiled_program._compile(scope, place) + if use_cuda_graph: + build_strategy.allow_cuda_graph_capture = True ir_graph = framework.IrGraph(compiled_program._graph) converted_program = ir_graph.to_program() @@ -1746,24 +1769,25 @@ class Executor: ) return False - # Unsupported case 4: CUDA Graph + # Unsupported case 4: async mode if ( compiled_program._build_strategy is not None - and compiled_program._build_strategy.allow_cuda_graph_capture + and compiled_program._build_strategy.async_mode ): warnings.warn( - "Standalone executor is not used for CUDA Graph", + "Standalone executor is not used for async mode", UserWarning, ) return False - # Unsupported case 5: async mode + # Unsupported case 5: CUDA Graph if ( compiled_program._build_strategy is not None - and compiled_program._build_strategy.async_mode + and compiled_program._build_strategy.allow_cuda_graph_capture + and not _is_cuda_graph_enable_standalone_executor() ): warnings.warn( - "Standalone executor is not used for async mode", + "Standalone executor is not used for CUDA Graph when FLAGS_CUDA_GRAPH_USE_STANDALONE_EXECUTOR=0", UserWarning, ) return False @@ -1811,8 +1835,13 @@ class Executor: tensor = core.get_variable_tensor(scope, lr_sheduler._var_name) # NOTE(dev): `tensor.set(data, self.place)` always call TensorCopySync that is a blocking behavior. So we use `_copy_from` to replace it. cpu_tensor = _as_lodtensor(data, core.CPUPlace()) - # for ipu, tensor is allocated on cpu - if core.is_compiled_with_ipu(): + if core.is_cuda_graph_capturing(): + warnings.warn( + "Caution!!! When capturing CUDA Graph, the learning rate scheduler would not " + "take any effect! Please set the learning rate manually before each batch!" 
+ ) + elif core.is_compiled_with_ipu(): + # for ipu, tensor is allocated on cpu tensor._copy_from(cpu_tensor, tensor._place()) else: tensor._copy_from(cpu_tensor, self.place) diff --git a/python/paddle/fluid/framework.py b/python/paddle/fluid/framework.py index 41b9b8bbb2deb74882180263e94cb3eb70709dc9..1245d0d28e65cc8d02a83d429641a4c077250a2b 100644 --- a/python/paddle/fluid/framework.py +++ b/python/paddle/fluid/framework.py @@ -86,6 +86,9 @@ _enable_standalone_executor_ = os.environ.get( _dy2st_enable_standalone_executor_ = os.environ.get( 'FLAGS_DY2ST_USE_STANDALONE_EXECUTOR', 1 ) +_cuda_graph_enable_standalone_executor_ = os.environ.get( + 'FLAGS_CUDA_GRAPH_USE_STANDALONE_EXECUTOR', 0 +) # Some explanation of our execution system 2022.03 # For now we have 3 kinds of execution system, since we refactored dygraph mode to diff --git a/python/paddle/fluid/tests/unittests/CMakeLists.txt b/python/paddle/fluid/tests/unittests/CMakeLists.txt index 6d99deb2bfc58019de279f3b41a8c910a1aab5a1..2eea2070befe39a230c4ac77306983abe0be4ca9 100755 --- a/python/paddle/fluid/tests/unittests/CMakeLists.txt +++ b/python/paddle/fluid/tests/unittests/CMakeLists.txt @@ -1259,3 +1259,7 @@ set_tests_properties(test_parallel_executor_dry_run PROPERTIES ENVIRONMENT "FLAGS_USE_STANDALONE_EXECUTOR=0") set_tests_properties(test_parallel_executor_drop_scope PROPERTIES ENVIRONMENT "FLAGS_USE_STANDALONE_EXECUTOR=0") + +set_tests_properties( + test_cuda_graph_static_mode + PROPERTIES ENVIRONMENT "FLAGS_CUDA_GRAPH_USE_STANDALONE_EXECUTOR=1") diff --git a/python/paddle/fluid/tests/unittests/test_cuda_graph.py b/python/paddle/fluid/tests/unittests/test_cuda_graph.py index d8ba91bad7b8a530fe3137b00b670f5e0878cf04..edfa7665882ca6e8180a5302114b2b5972585f9b 100644 --- a/python/paddle/fluid/tests/unittests/test_cuda_graph.py +++ b/python/paddle/fluid/tests/unittests/test_cuda_graph.py @@ -18,18 +18,16 @@ import shutil import unittest import numpy as np -from simple_nets import simple_fc_net_with_inputs 
import paddle from paddle.device.cuda.graphs import CUDAGraph -from paddle.fluid.dygraph.base import switch_to_static_graph def can_use_cuda_graph(): return paddle.is_compiled_with_cuda() and not paddle.is_compiled_with_rocm() -class TestCUDAGraph(unittest.TestCase): +class TestCUDAGraphInDygraphMode(unittest.TestCase): def setUp(self): if can_use_cuda_graph(): paddle.set_flags( @@ -46,94 +44,6 @@ class TestCUDAGraph(unittest.TestCase): np.random.randint(low=0, high=10, size=shape).astype("float32") ) - @switch_to_static_graph - def test_cuda_graph_static_graph(self): - if not can_use_cuda_graph(): - return - - seed = 100 - loss_cuda_graph = self.cuda_graph_static_graph_main( - seed, use_cuda_graph=True - ) - loss_no_cuda_graph = self.cuda_graph_static_graph_main( - seed, use_cuda_graph=False - ) - self.assertEqual(loss_cuda_graph, loss_no_cuda_graph) - - def cuda_graph_static_graph_main(self, seed, use_cuda_graph): - batch_size = 1 - class_num = 10 - image_shape = [batch_size, 784] - label_shape = [batch_size, 1] - - paddle.seed(seed) - np.random.seed(seed) - startup = paddle.static.Program() - main = paddle.static.Program() - with paddle.static.program_guard(main, startup): - image = paddle.static.data( - name="image", shape=image_shape, dtype='float32' - ) - label = paddle.static.data( - name="label", shape=label_shape, dtype='int64' - ) - image.persistable = True - label.persistable = True - loss = simple_fc_net_with_inputs(image, label, class_num) - loss.persistable = True - lr = paddle.optimizer.lr.PiecewiseDecay( - boundaries=[2, 3, 4], values=[0.01, 0.02, 0.03, 0.04] - ) - optimizer = paddle.optimizer.SGD(learning_rate=lr) - optimizer.minimize(loss) - place = paddle.CUDAPlace(0) - exe = paddle.static.Executor(place) - scope = paddle.static.Scope() - with paddle.static.scope_guard(scope): - exe.run(startup) - build_strategy = paddle.static.BuildStrategy() - build_strategy.allow_cuda_graph_capture = True - build_strategy.fix_op_run_order = True - 
build_strategy.fuse_all_optimizer_ops = True - compiled_program = paddle.static.CompiledProgram( - main - ).with_data_parallel( - loss_name=loss.name, build_strategy=build_strategy, places=place - ) - image_t = scope.var(image.name).get_tensor() - label_t = scope.var(label.name).get_tensor() - loss_t = scope.var(loss.name).get_tensor() - lr_var = main.global_block().var(lr._var_name) - self.assertTrue(lr_var.persistable) - lr_t = scope.var(lr_var.name).get_tensor() - cuda_graph = None - for batch_id in range(20): - image_t.set( - np.random.rand(*image_shape).astype('float32'), place - ) - label_t.set( - np.random.randint( - low=0, high=class_num, size=label_shape, dtype='int64' - ), - place, - ) - - if batch_id == 1 and use_cuda_graph: - cuda_graph = CUDAGraph(place, mode="global") - cuda_graph.capture_begin() - exe.run(compiled_program) - cuda_graph.capture_end() - - if cuda_graph: - lr_t.set(np.array([lr()], dtype='float32'), place) - cuda_graph.replay() - else: - exe.run(compiled_program) - lr.step() - if cuda_graph: - cuda_graph.reset() - return np.array(loss_t) - def test_cuda_graph_dynamic_graph(self): if not can_use_cuda_graph(): return diff --git a/python/paddle/fluid/tests/unittests/test_cuda_graph_static_mode.py b/python/paddle/fluid/tests/unittests/test_cuda_graph_static_mode.py new file mode 100644 index 0000000000000000000000000000000000000000..e159334c87a6492e50d51d057c6c5ab8513a9e96 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_cuda_graph_static_mode.py @@ -0,0 +1,139 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import numpy as np +from simple_nets import simple_fc_net_with_inputs + +import paddle +from paddle.device.cuda.graphs import CUDAGraph +from paddle.fluid.dygraph.base import switch_to_static_graph + + +def can_use_cuda_graph(): + return paddle.is_compiled_with_cuda() and not paddle.is_compiled_with_rocm() + + +class TestCUDAGraphInStaticMode(unittest.TestCase): + def setUp(self): + if can_use_cuda_graph(): + # The behavior of `FLAGS_use_stream_safe_cuda_allocator` in static + # mode is inconsistent with that in dygraph mode. + # In static mode, FLAGS_use_stream_safe_cuda_allocator must be True. + # In dygraph mode, FLAGS_use_stream_safe_cuda_allocator must be False. + # These two types of unittests need to be written separately, because + # the allocator may only be initialized once, and the flag + # `FLAGS_use_stream_safe_cuda_allocator` only takes effect during + # initialization. 
+ paddle.set_flags( + { + 'FLAGS_allocator_strategy': 'auto_growth', + 'FLAGS_sync_nccl_allreduce': False, + 'FLAGS_cudnn_deterministic': True, + 'FLAGS_use_stream_safe_cuda_allocator': True, + } + ) + + @switch_to_static_graph + def test_cuda_graph_static_graph(self): + if not can_use_cuda_graph(): + return + + seed = 100 + loss_cuda_graph = self.cuda_graph_static_graph_main( + seed, use_cuda_graph=True + ) + loss_no_cuda_graph = self.cuda_graph_static_graph_main( + seed, use_cuda_graph=False + ) + self.assertEqual(loss_cuda_graph, loss_no_cuda_graph) + + def cuda_graph_static_graph_main(self, seed, use_cuda_graph): + batch_size = 1 + class_num = 10 + image_shape = [batch_size, 784] + label_shape = [batch_size, 1] + + paddle.seed(seed) + np.random.seed(seed) + startup = paddle.static.Program() + main = paddle.static.Program() + with paddle.static.program_guard(main, startup): + image = paddle.static.data( + name="image", shape=image_shape, dtype='float32' + ) + label = paddle.static.data( + name="label", shape=label_shape, dtype='int64' + ) + image.persistable = True + label.persistable = True + loss = simple_fc_net_with_inputs(image, label, class_num) + loss.persistable = True + lr = paddle.optimizer.lr.PiecewiseDecay( + boundaries=[2, 3, 4], values=[0.01, 0.02, 0.03, 0.04] + ) + optimizer = paddle.optimizer.SGD(learning_rate=lr) + optimizer.minimize(loss) + place = paddle.CUDAPlace(0) + exe = paddle.static.Executor(place) + scope = paddle.static.Scope() + with paddle.static.scope_guard(scope): + exe.run(startup) + build_strategy = paddle.static.BuildStrategy() + build_strategy.allow_cuda_graph_capture = True + build_strategy.fix_op_run_order = True + build_strategy.fuse_all_optimizer_ops = True + compiled_program = paddle.static.CompiledProgram( + main + ).with_data_parallel( + loss_name=loss.name, build_strategy=build_strategy, places=place + ) + image_t = scope.var(image.name).get_tensor() + label_t = scope.var(label.name).get_tensor() + loss_t = 
scope.var(loss.name).get_tensor() + lr_var = main.global_block().var(lr._var_name) + self.assertTrue(lr_var.persistable) + lr_t = scope.var(lr_var.name).get_tensor() + cuda_graph = None + for batch_id in range(20): + image_t.set( + np.random.rand(*image_shape).astype('float32'), place + ) + label_t.set( + np.random.randint( + low=0, high=class_num, size=label_shape, dtype='int64' + ), + place, + ) + + if batch_id == 1 and use_cuda_graph: + cuda_graph = CUDAGraph(place, mode="global") + cuda_graph.capture_begin() + exe.run(compiled_program) + cuda_graph.capture_end() + + if cuda_graph: + lr_t.set(np.array([lr()], dtype='float32'), place) + cuda_graph.replay() + else: + exe.run(compiled_program) + lr.step() + if cuda_graph: + cuda_graph.reset() + return np.array(loss_t) + + +if __name__ == "__main__": + unittest.main() diff --git a/tools/parallel_UT_rule.py b/tools/parallel_UT_rule.py index 0a898caa3fb2f38791ca2e73215674d7d95ceefa..f5c57a312d84583bcf367ccf85666ce80363a2ec 100755 --- a/tools/parallel_UT_rule.py +++ b/tools/parallel_UT_rule.py @@ -623,6 +623,7 @@ HIGH_PARALLEL_JOB_NEW = [ 'test_dataset_consistency_inspection', 'test_cuda_empty_cache', 'test_cuda_graph', + 'test_cuda_graph_static_mode', 'test_disable_signal_handler', 'test_eig_op', 'test_eigh_op', @@ -2509,6 +2510,7 @@ TETRAD_PARALLEL_JOB = [ 'test_dlpack', 'test_complex_variable', 'test_cuda_graph', + 'test_cuda_graph_static_mode', 'test_custom_grad_input', 'test_accuracy_op', 'test_pool1d_api',