未验证 提交 ea47d211 编写于 作者: Y Yiqun Liu 提交者: GitHub

Make FLAGS_determinstic effective in conv2d forward. (#37173)

* Make FLAGS_determinstic effective in conv2d forward.

* Add call of SetCinnCudnnDeterministic in cinn_launch op.
上级 5091fed7
......@@ -15,6 +15,8 @@
#include "paddle/fluid/operators/cinn_launch_op.h"
#include "paddle/fluid/string/string_helper.h"
DECLARE_bool(cudnn_deterministic);
namespace paddle {
namespace operators {
......@@ -67,6 +69,12 @@ void LaunchCinnExecution(const CinnCompiledObject& compiled_obj,
compiled_obj.runtime_program->Execute(&context.FinalizeArguments());
}
void SetCinnRuntimeFlags() {
VLOG(4) << "Set FLAGS_cinn_cudnn_deterministic to "
<< FLAGS_cudnn_deterministic;
::cinn::runtime::SetCinnCudnnDeterministic(FLAGS_cudnn_deterministic);
}
CinnLaunchContext::CinnLaunchContext(const CinnCompiledObject& compiled_obj)
: paddle2cinn_varmap_(compiled_obj.paddle2cinn_varmap),
cinn_scope_(compiled_obj.scope) {
......
......@@ -21,6 +21,7 @@
#include "cinn/hlir/framework/graph_compiler.h"
#include "cinn/hlir/framework/scope.h"
#include "cinn/runtime/cinn_runtime.h"
#include "cinn/runtime/flags.h"
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
......@@ -110,6 +111,9 @@ void DebugCinnCompiledResult(const CinnCompiledObject& result);
// Launch cinn to execute compiled executable program and wait done
void LaunchCinnExecution(const CinnCompiledObject& compiled_obj,
const CinnLaunchContext& context);
// Set cinn FLAGS (such as FLAGS_cinn_cudnn_deterministic) with paddle's FLAGS.
void SetCinnRuntimeFlags();
} // namespace details
template <typename DeviceContext, typename T>
......@@ -202,7 +206,10 @@ class CinnLaunchOpKernel : public framework::OpKernel<T> {
launch_context->AssignInternalVariable(var_name, tensor);
}
// Step 4. Launch CINN to execute the compiled executable program
// Step 4. Set CINN runtime FLAGS, such as FLAGS_cinn_cudnn_deterministic.
details::SetCinnRuntimeFlags();
// Step 5. Launch CINN to execute the compiled executable program
details::LaunchCinnExecution(cinn_compiled_object, *launch_context);
VLOG(4) << "CinnLaunchOp launch execution done.";
}
......
......@@ -298,11 +298,12 @@ class CUDNNConvOpKernel : public framework::OpKernel<T> {
miopenConvFwdAlgorithm_t algo{};
using search = SearchAlgorithm<miopenConvFwdAlgorithm_t>;
workspace_size = search::GetWorkspaceSize(args);
algo = search::Find<T>(args, exhaustive_search, false, workspace_size, ctx);
algo = search::Find<T>(args, exhaustive_search, deterministic,
workspace_size, ctx);
#else
cudnnConvolutionFwdAlgo_t algo{};
using search = SearchAlgorithm<cudnnConvolutionFwdAlgoPerf_t>;
algo = search::Find<T>(args, exhaustive_search, false, ctx);
algo = search::Find<T>(args, exhaustive_search, deterministic, ctx);
workspace_size = search::GetWorkspaceSize(args, algo);
#endif
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册