From ea47d211d3a720b5b5c3c1c346b10c204bb832f3 Mon Sep 17 00:00:00 2001 From: Yiqun Liu Date: Tue, 16 Nov 2021 11:03:48 +0800 Subject: [PATCH] Make FLAGS_determinstic effective in conv2d forward. (#37173) * Make FLAGS_determinstic effective in conv2d forward. * Add call of SetCinnCudnnDeterministic in cinn_launch op. --- paddle/fluid/operators/cinn_launch_op.cc | 8 ++++++++ paddle/fluid/operators/cinn_launch_op.h | 9 ++++++++- paddle/fluid/operators/conv_cudnn_op.cu | 5 +++-- 3 files changed, 19 insertions(+), 3 deletions(-) diff --git a/paddle/fluid/operators/cinn_launch_op.cc b/paddle/fluid/operators/cinn_launch_op.cc index 3b3e21ed1e1..51c5183241a 100644 --- a/paddle/fluid/operators/cinn_launch_op.cc +++ b/paddle/fluid/operators/cinn_launch_op.cc @@ -15,6 +15,8 @@ #include "paddle/fluid/operators/cinn_launch_op.h" #include "paddle/fluid/string/string_helper.h" +DECLARE_bool(cudnn_deterministic); + namespace paddle { namespace operators { @@ -67,6 +69,12 @@ void LaunchCinnExecution(const CinnCompiledObject& compiled_obj, compiled_obj.runtime_program->Execute(&context.FinalizeArguments()); } +void SetCinnRuntimeFlags() { + VLOG(4) << "Set FLAGS_cinn_cudnn_deterministic to " + << FLAGS_cudnn_deterministic; + ::cinn::runtime::SetCinnCudnnDeterministic(FLAGS_cudnn_deterministic); +} + CinnLaunchContext::CinnLaunchContext(const CinnCompiledObject& compiled_obj) : paddle2cinn_varmap_(compiled_obj.paddle2cinn_varmap), cinn_scope_(compiled_obj.scope) { diff --git a/paddle/fluid/operators/cinn_launch_op.h b/paddle/fluid/operators/cinn_launch_op.h index 4e1a05a7a32..348d1dda027 100644 --- a/paddle/fluid/operators/cinn_launch_op.h +++ b/paddle/fluid/operators/cinn_launch_op.h @@ -21,6 +21,7 @@ #include "cinn/hlir/framework/graph_compiler.h" #include "cinn/hlir/framework/scope.h" #include "cinn/runtime/cinn_runtime.h" +#include "cinn/runtime/flags.h" #include "paddle/fluid/framework/data_type.h" #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/operator.h" @@ -110,6 +111,9 @@ void DebugCinnCompiledResult(const CinnCompiledObject& result); // Launch cinn to execute compiled executable program and wait done void LaunchCinnExecution(const CinnCompiledObject& compiled_obj, const CinnLaunchContext& context); + +// Set cinn FLAGS (such as FLAGS_cinn_cudnn_deterministic) with paddle's FLAGS. +void SetCinnRuntimeFlags(); } // namespace details template @@ -202,7 +206,10 @@ class CinnLaunchOpKernel : public framework::OpKernel { launch_context->AssignInternalVariable(var_name, tensor); } - // Step 4. Launch CINN to execute the compiled executable program + // Step 4. Set CINN runtime FLAGS, such as FLAGS_cinn_cudnn_deterministic. + details::SetCinnRuntimeFlags(); + + // Step 5. Launch CINN to execute the compiled executable program details::LaunchCinnExecution(cinn_compiled_object, *launch_context); VLOG(4) << "CinnLaunchOp launch execution done."; } diff --git a/paddle/fluid/operators/conv_cudnn_op.cu b/paddle/fluid/operators/conv_cudnn_op.cu index c49a3ee1c20..275e81fc7f3 100644 --- a/paddle/fluid/operators/conv_cudnn_op.cu +++ b/paddle/fluid/operators/conv_cudnn_op.cu @@ -298,11 +298,12 @@ class CUDNNConvOpKernel : public framework::OpKernel { miopenConvFwdAlgorithm_t algo{}; using search = SearchAlgorithm; workspace_size = search::GetWorkspaceSize(args); - algo = search::Find(args, exhaustive_search, false, workspace_size, ctx); + algo = search::Find(args, exhaustive_search, deterministic, + workspace_size, ctx); #else cudnnConvolutionFwdAlgo_t algo{}; using search = SearchAlgorithm; - algo = search::Find(args, exhaustive_search, false, ctx); + algo = search::Find(args, exhaustive_search, deterministic, ctx); workspace_size = search::GetWorkspaceSize(args, algo); #endif -- GitLab