未验证 提交 ba6dcbc9 编写于 作者: R ronnywang 提交者: GitHub

[PHI CAPI] support get & set random seed (#55659)

上级 a88d36aa
...@@ -36,6 +36,16 @@ void *PD_DeviceContextAllocateTensor(const PD_DeviceContext *ctx, ...@@ -36,6 +36,16 @@ void *PD_DeviceContextAllocateTensor(const PD_DeviceContext *ctx,
PD_DataType dtype, PD_DataType dtype,
PD_Status *status); PD_Status *status);
void PD_DeviceContextSetSeed(const PD_DeviceContext *ctx,
uint64_t seed,
PD_Status *status);
uint64_t PD_DeviceContextGetSeed(const PD_DeviceContext *ctx,
PD_Status *status);
uint64_t PD_DeviceContextGetRandom(const PD_DeviceContext *ctx,
PD_Status *status);
#ifdef __cplusplus #ifdef __cplusplus
} // extern "C" } // extern "C"
#endif #endif
......
...@@ -298,6 +298,26 @@ class DeviceContext : public WrapperBase<PD_DeviceContext> { ...@@ -298,6 +298,26 @@ class DeviceContext : public WrapperBase<PD_DeviceContext> {
PD_CHECK_STATUS(status); PD_CHECK_STATUS(status);
return static_cast<T*>(ptr); return static_cast<T*>(ptr);
} }
uint64_t seed() const {
C_Status status;
auto seed_val = PD_DeviceContextGetSeed(raw_data(), &status);
PD_CHECK_STATUS(status);
return seed_val;
}
void seed(uint64_t seed_val) const {
C_Status status;
PD_DeviceContextSetSeed(raw_data(), seed_val, &status);
PD_CHECK_STATUS(status);
}
uint64_t random() const {
C_Status status;
auto rand_val = PD_DeviceContextGetRandom(raw_data(), &status);
PD_CHECK_STATUS(status);
return rand_val;
}
}; };
class Scalar : public WrapperBase<PD_Scalar> { class Scalar : public WrapperBase<PD_Scalar> {
......
...@@ -74,4 +74,32 @@ void* PD_DeviceContextAllocateTensor(const PD_DeviceContext* ctx, ...@@ -74,4 +74,32 @@ void* PD_DeviceContextAllocateTensor(const PD_DeviceContext* ctx,
} }
} }
void PD_DeviceContextSetSeed(const PD_DeviceContext* ctx,
uint64_t seed,
PD_Status* status) {
if (status) {
*status = C_SUCCESS;
}
auto dev_ctx = reinterpret_cast<const phi::DeviceContext*>(ctx);
dev_ctx->GetGenerator()->SetCurrentSeed(seed);
}
uint64_t PD_DeviceContextGetSeed(const PD_DeviceContext* ctx,
PD_Status* status) {
if (status) {
*status = C_SUCCESS;
}
auto dev_ctx = reinterpret_cast<const phi::DeviceContext*>(ctx);
return dev_ctx->GetGenerator()->GetCurrentSeed();
}
uint64_t PD_DeviceContextGetRandom(const PD_DeviceContext* ctx,
PD_Status* status) {
if (status) {
*status = C_SUCCESS;
}
auto dev_ctx = reinterpret_cast<const phi::DeviceContext*>(ctx);
return dev_ctx->GetGenerator()->Random64();
}
PD_REGISTER_CAPI(device_context); PD_REGISTER_CAPI(device_context);
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册