未验证 提交 ba6dcbc9 编写于 作者: R ronnywang 提交者: GitHub

[PHI CAPI] support get & set random seed (#55659)

上级 a88d36aa
......@@ -36,6 +36,16 @@ void *PD_DeviceContextAllocateTensor(const PD_DeviceContext *ctx,
PD_DataType dtype,
PD_Status *status);
void PD_DeviceContextSetSeed(const PD_DeviceContext *ctx,
uint64_t seed,
PD_Status *status);
uint64_t PD_DeviceContextGetSeed(const PD_DeviceContext *ctx,
PD_Status *status);
uint64_t PD_DeviceContextGetRandom(const PD_DeviceContext *ctx,
PD_Status *status);
#ifdef __cplusplus
} // extern "C"
#endif
......
......@@ -298,6 +298,26 @@ class DeviceContext : public WrapperBase<PD_DeviceContext> {
PD_CHECK_STATUS(status);
return static_cast<T*>(ptr);
}
uint64_t seed() const {
C_Status status;
auto seed_val = PD_DeviceContextGetSeed(raw_data(), &status);
PD_CHECK_STATUS(status);
return seed_val;
}
void seed(uint64_t seed_val) const {
C_Status status;
PD_DeviceContextSetSeed(raw_data(), seed_val, &status);
PD_CHECK_STATUS(status);
}
uint64_t random() const {
C_Status status;
auto rand_val = PD_DeviceContextGetRandom(raw_data(), &status);
PD_CHECK_STATUS(status);
return rand_val;
}
};
class Scalar : public WrapperBase<PD_Scalar> {
......
......@@ -74,4 +74,32 @@ void* PD_DeviceContextAllocateTensor(const PD_DeviceContext* ctx,
}
}
void PD_DeviceContextSetSeed(const PD_DeviceContext* ctx,
uint64_t seed,
PD_Status* status) {
if (status) {
*status = C_SUCCESS;
}
auto dev_ctx = reinterpret_cast<const phi::DeviceContext*>(ctx);
dev_ctx->GetGenerator()->SetCurrentSeed(seed);
}
uint64_t PD_DeviceContextGetSeed(const PD_DeviceContext* ctx,
PD_Status* status) {
if (status) {
*status = C_SUCCESS;
}
auto dev_ctx = reinterpret_cast<const phi::DeviceContext*>(ctx);
return dev_ctx->GetGenerator()->GetCurrentSeed();
}
uint64_t PD_DeviceContextGetRandom(const PD_DeviceContext* ctx,
PD_Status* status) {
if (status) {
*status = C_SUCCESS;
}
auto dev_ctx = reinterpret_cast<const phi::DeviceContext*>(ctx);
return dev_ctx->GetGenerator()->Random64();
}
PD_REGISTER_CAPI(device_context);
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册