未验证 提交 94aea284 编写于 作者: J jiangcheng 提交者: GitHub

[CINN] pass global seed to CINN (#52078)

* [CINN] pass global seed to CINN

* fix cu not include cinn/runtime/flags.h bug

* fix DefaultCUDAGenerator should has device id bug
上级 929892c3
......@@ -21,6 +21,7 @@
#include "cinn/runtime/cinn_runtime.h"
#include "cinn/runtime/flags.h"
#include "paddle/fluid/string/string_helper.h"
#include "paddle/phi/core/generator.h"
DECLARE_bool(cudnn_deterministic);
......@@ -86,6 +87,12 @@ void SetCinnRuntimeFlags() {
::cinn::runtime::SetCinnCudnnDeterministic(FLAGS_cudnn_deterministic);
}
template <>
void SetCinnRandomSeed<phi::CPUContext>() {
auto seed = phi::DefaultCPUGenerator()->GetCurrentSeed();
::cinn::runtime::RandomSeed::GetOrSet(seed);
}
} // namespace details
class CinnLaunchOp : public framework::OperatorWithKernel {
......
......@@ -14,7 +14,24 @@ limitations under the License. */
#include "paddle/fluid/operators/cinn/cinn_launch_op.h"
#include "cinn/runtime/flags.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/phi/core/generator.h"
namespace paddle {
namespace operators {
namespace details {
template <>
void SetCinnRandomSeed<phi::GPUContext>() {
auto seed = phi::DefaultCUDAGenerator(0)->GetCurrentSeed();
::cinn::runtime::RandomSeed::GetOrSet(seed);
}
} // namespace details
} // namespace operators
} // namespace paddle
/* see [Why use single type kernel] */
REGISTER_OP_CUDA_KERNEL(
......
......@@ -54,6 +54,10 @@ void LaunchCinnExecution(const CinnCompiledObject& compiled_obj,
// Set cinn FLAGS (such as FLAGS_cinn_cudnn_deterministic) with paddle's FLAGS.
void SetCinnRuntimeFlags();
// set CINN global random seed
template <typename DeviceContext>
void SetCinnRandomSeed();
} // namespace details
template <typename DeviceContext, typename T>
......@@ -133,6 +137,9 @@ class CinnLaunchOpKernel : public framework::OpKernel<T> {
// Step 3. Set CINN runtime FLAGS, such as FLAGS_cinn_cudnn_deterministic.
details::SetCinnRuntimeFlags();
// set CINN global random seed
details::SetCinnRandomSeed<DeviceContext>();
// Step 4. Execute the compiled CINN instructions by a PE or
// by the CINN compiled program in sequential order
if (FLAGS_enable_pe_launch_cinn) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册