#include "paddle/operators/random_op.h"

#include <stdexcept>
#include <string>

#include "paddle/framework/op_registry.h"

namespace paddle {
namespace operators {

// GPU kernel for the `gaussian_random` operator: fills output 0 with samples
// from a normal distribution with the "mean" and "std" attributes, using the
// cuRAND generator obtained from the CUDA device context.
//
// NOTE(review): written as the GPU partial specialization of a primary
// template GaussianRandomOpKernel<Place, T> assumed to be declared in
// random_op.h — confirm against that header. curandGenerateNormal writes
// float, so only T = float is usable here.
template <typename T>
class GaussianRandomOpKernel<platform::GPUPlace, T>
    : public framework::OpKernel {
 public:
  void Compute(const framework::KernelContext& context) const override {
    const float mean = context.op_.GetAttr<float>("mean");
    // Named `stddev` rather than `std` to avoid confusion with namespace std.
    const float stddev = context.op_.GetAttr<float>("std");

    auto* output = context.Output(0)->GetMutable<framework::Tensor>();
    T* data = output->mutable_data<T>(context.GetPlace());
    const size_t numel =
        static_cast<size_t>(framework::product(output->dims()));
    if (numel == 0) {
      return;  // Empty tensor: nothing to generate.
    }

    auto* ctx = static_cast<const platform::CUDADeviceContext*>(
        context.device_context_);
    // Obtaining the generator needs to modify the context, so drop const here.
    auto generator =
        const_cast<platform::CUDADeviceContext*>(ctx)->RandGenerator();

    const auto status =
        curandGenerateNormal(generator, data, numel, mean, stddev);
    if (status != CURAND_STATUS_SUCCESS) {
      // The status used to be ignored. Per the cuRAND docs, pseudorandom
      // generators require an even element count for normal output (and
      // return CURAND_STATUS_LENGTH_NOT_MULTIPLE otherwise), so an odd-sized
      // tensor failed silently. Surface the failure instead.
      throw std::runtime_error(
          "gaussian_random: curandGenerateNormal failed with status " +
          std::to_string(static_cast<int>(status)) + " for " +
          std::to_string(numel) + " elements");
    }
  }
};

}  // namespace operators
}  // namespace paddle

// Alias keeps the registration macro argument free of a top-level comma.
// Its name now matches the one passed to REGISTER_OP_GPU_KERNEL (previously
// the alias was RandomOpKernel_GPU_float, leaving the registered name
// undeclared).
typedef paddle::operators::GaussianRandomOpKernel<paddle::platform::GPUPlace,
                                                  float>
    GaussianRandomOpKernel_GPU_float;
REGISTER_OP_GPU_KERNEL(gaussian_random, GaussianRandomOpKernel_GPU_float);