未验证 提交 ea47d211 编写于 作者: Y Yiqun Liu 提交者: GitHub

Make FLAGS_determinstic effective in conv2d forward. (#37173)

* Make FLAGS_determinstic effective in conv2d forward.

* Add call of SetCinnCudnnDeterministic in cinn_launch op.
上级 5091fed7
...@@ -15,6 +15,8 @@ ...@@ -15,6 +15,8 @@
#include "paddle/fluid/operators/cinn_launch_op.h" #include "paddle/fluid/operators/cinn_launch_op.h"
#include "paddle/fluid/string/string_helper.h" #include "paddle/fluid/string/string_helper.h"
DECLARE_bool(cudnn_deterministic);
namespace paddle { namespace paddle {
namespace operators { namespace operators {
...@@ -67,6 +69,12 @@ void LaunchCinnExecution(const CinnCompiledObject& compiled_obj, ...@@ -67,6 +69,12 @@ void LaunchCinnExecution(const CinnCompiledObject& compiled_obj,
compiled_obj.runtime_program->Execute(&context.FinalizeArguments()); compiled_obj.runtime_program->Execute(&context.FinalizeArguments());
} }
void SetCinnRuntimeFlags() {
VLOG(4) << "Set FLAGS_cinn_cudnn_deterministic to "
<< FLAGS_cudnn_deterministic;
::cinn::runtime::SetCinnCudnnDeterministic(FLAGS_cudnn_deterministic);
}
CinnLaunchContext::CinnLaunchContext(const CinnCompiledObject& compiled_obj) CinnLaunchContext::CinnLaunchContext(const CinnCompiledObject& compiled_obj)
: paddle2cinn_varmap_(compiled_obj.paddle2cinn_varmap), : paddle2cinn_varmap_(compiled_obj.paddle2cinn_varmap),
cinn_scope_(compiled_obj.scope) { cinn_scope_(compiled_obj.scope) {
......
...@@ -21,6 +21,7 @@ ...@@ -21,6 +21,7 @@
#include "cinn/hlir/framework/graph_compiler.h" #include "cinn/hlir/framework/graph_compiler.h"
#include "cinn/hlir/framework/scope.h" #include "cinn/hlir/framework/scope.h"
#include "cinn/runtime/cinn_runtime.h" #include "cinn/runtime/cinn_runtime.h"
#include "cinn/runtime/flags.h"
#include "paddle/fluid/framework/data_type.h" #include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h" #include "paddle/fluid/framework/operator.h"
...@@ -110,6 +111,9 @@ void DebugCinnCompiledResult(const CinnCompiledObject& result); ...@@ -110,6 +111,9 @@ void DebugCinnCompiledResult(const CinnCompiledObject& result);
// Launch cinn to execute compiled executable program and wait done // Launch cinn to execute compiled executable program and wait done
void LaunchCinnExecution(const CinnCompiledObject& compiled_obj, void LaunchCinnExecution(const CinnCompiledObject& compiled_obj,
const CinnLaunchContext& context); const CinnLaunchContext& context);
// Set cinn FLAGS (such as FLAGS_cinn_cudnn_deterministic) with paddle's FLAGS.
void SetCinnRuntimeFlags();
} // namespace details } // namespace details
template <typename DeviceContext, typename T> template <typename DeviceContext, typename T>
...@@ -202,7 +206,10 @@ class CinnLaunchOpKernel : public framework::OpKernel<T> { ...@@ -202,7 +206,10 @@ class CinnLaunchOpKernel : public framework::OpKernel<T> {
launch_context->AssignInternalVariable(var_name, tensor); launch_context->AssignInternalVariable(var_name, tensor);
} }
// Step 4. Launch CINN to execute the compiled executable program // Step 4. Set CINN runtime FLAGS, such as FLAGS_cinn_cudnn_deterministic.
details::SetCinnRuntimeFlags();
// Step 5. Launch CINN to execute the compiled executable program
details::LaunchCinnExecution(cinn_compiled_object, *launch_context); details::LaunchCinnExecution(cinn_compiled_object, *launch_context);
VLOG(4) << "CinnLaunchOp launch execution done."; VLOG(4) << "CinnLaunchOp launch execution done.";
} }
......
...@@ -298,11 +298,12 @@ class CUDNNConvOpKernel : public framework::OpKernel<T> { ...@@ -298,11 +298,12 @@ class CUDNNConvOpKernel : public framework::OpKernel<T> {
miopenConvFwdAlgorithm_t algo{}; miopenConvFwdAlgorithm_t algo{};
using search = SearchAlgorithm<miopenConvFwdAlgorithm_t>; using search = SearchAlgorithm<miopenConvFwdAlgorithm_t>;
workspace_size = search::GetWorkspaceSize(args); workspace_size = search::GetWorkspaceSize(args);
algo = search::Find<T>(args, exhaustive_search, false, workspace_size, ctx); algo = search::Find<T>(args, exhaustive_search, deterministic,
workspace_size, ctx);
#else #else
cudnnConvolutionFwdAlgo_t algo{}; cudnnConvolutionFwdAlgo_t algo{};
using search = SearchAlgorithm<cudnnConvolutionFwdAlgoPerf_t>; using search = SearchAlgorithm<cudnnConvolutionFwdAlgoPerf_t>;
algo = search::Find<T>(args, exhaustive_search, false, ctx); algo = search::Find<T>(args, exhaustive_search, deterministic, ctx);
workspace_size = search::GetWorkspaceSize(args, algo); workspace_size = search::GetWorkspaceSize(args, algo);
#endif #endif
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册