#pragma once #include #include "glog/logging.h" #include "paddle/framework/eigen.h" #include "paddle/framework/operator.h" namespace paddle { namespace operators { template bool Gaussian(platform::CPUDeviceContext* ctx, T* output, const int size, const T& mean, const T& std, const T& seed) { auto g = ctx->RandGenerator(seed); std::normal_distribution distribution(mean, std); for (int i = 0; i < size; ++i) { output[i] = distribution(g); } return true; } #ifndef PADDLE_ONLY_CPU template bool Gaussian(platform::CUDADeviceContext* ctx, T* output, const int size, const T& mean, const T& std, const T& seed) { auto g = ctx->RandGenerator(seed); return curandGenerateNormal(g, output, size, mean, std); } #endif template class RandomOpKernel : public framework::OpKernel { public: void Compute(const framework::KernelContext& context) const override { auto mean = context.op_.GetAttr("mean"); auto std = context.op_.GetAttr("std"); auto seed = context.op_.GetAttr("seed"); auto* output = context.Output(0)->GetMutable(); auto place = context.GetPlace(); if (platform::is_cpu_place(place)) { Gaussian( dynamic_cast(context.device_context_), output->mutable_data(context.GetPlace()), framework::product(output->dims()), mean, std, seed); } else { Gaussian( dynamic_cast(context.device_context_), output->mutable_data(context.GetPlace()), framework::product(output->dims()), mean, std, seed); } } }; } // namespace operators } // namespace paddle