// gaussian_random_op.cu (1.3 KB)
// Web blame-view residue: "Newer Older", source lines 1-5.
#include <memory>
#include <random>
#include "paddle/platform/dynload/curand.h"
#include "paddle/platform/gpu_info.h"

// Blame: committed by dongzhihong; source lines 6-7.
#include "paddle/framework/op_registry.h"

// Blame: committed by dongzhihong; source lines 8-9.
namespace paddle {
namespace operators {

// GPU kernel for the gaussian_random operator: fills the single output
// tensor with samples drawn from a normal distribution with the given mean
// and standard deviation.
//
// Attributes read:
//   "mean" (T)   - mean of the distribution.
//   "std"  (T)   - standard deviation, forwarded to the generator as-is.
//   "seed" (int) - generator seed; 0 means "draw one from std::random_device".
//
// Only T = float is instantiated (curandGenerateNormal writes float samples).
template <typename T>
class GaussianRandomKernel : public framework::OpKernel {
 public:
  void Compute(const framework::ExecutionContext& context) const override {
    const T mean = context.op_.GetAttr<T>("mean");
    const T stddev = context.op_.GetAttr<T>("std");
    auto* tensor = context.Output<framework::Tensor>(0);
    T* data = tensor->mutable_data<T>(context.GetPlace());

    int seed = context.op_.GetAttr<int>("seed");
    if (seed == 0) {
      seed = std::random_device()();
    }

    const auto numel = framework::product(tensor->dims());
    if (numel == 0) {
      return;  // Empty output: nothing to generate.
    }
    const size_t n = static_cast<size_t>(numel);

    // For pseudo-random generators curandGenerateNormal uses a Box-Muller
    // transform and rejects odd sample counts
    // (CURAND_STATUS_LENGTH_NOT_MULTIPLE). Generate the largest even prefix
    // on the device and fill a possible trailing element separately.
    const size_t even_n = n - n % 2;
    if (even_n > 0) {
      // TODO(review): a generator is created per call and is not bound to
      // the device context's stream; an earlier draft referenced
      // GPUDeviceContext::RandGenerator() for reuse - confirm the intent.
      CurandGenerator generator(CURAND_RNG_PSEUDO_DEFAULT);
      PADDLE_ENFORCE(platform::dynload::curandSetPseudoRandomGeneratorSeed(
          generator.get(), seed));
      PADDLE_ENFORCE(platform::dynload::curandGenerateNormal(
          generator.get(), data, even_n, mean, stddev));
    }

    if (even_n != n) {
      // One sample left over. Draw it on the host from an engine seeded with
      // the same seed (deterministic for a fixed non-zero seed) and copy it
      // to its slot. Scaling a standard normal avoids
      // std::normal_distribution's stddev > 0 precondition.
      std::mt19937 engine(static_cast<std::mt19937::result_type>(seed));
      std::normal_distribution<T> standard_normal(static_cast<T>(0),
                                                  static_cast<T>(1));
      const T last = mean + stddev * standard_normal(engine);
      PADDLE_ENFORCE(cudaMemcpy(data + even_n, &last, sizeof(T),
                                cudaMemcpyHostToDevice));
    }
  }

 private:
  // Owns a cuRAND generator handle and destroys it when the owner leaves
  // scope, including when a PADDLE_ENFORCE in between throws.
  class CurandGenerator {
   public:
    explicit CurandGenerator(curandRngType_t rng_type) {
      PADDLE_ENFORCE(
          platform::dynload::curandCreateGenerator(&handle_, rng_type));
    }

    ~CurandGenerator() {
      // Status deliberately ignored: a destructor must not throw.
      (void)platform::dynload::curandDestroyGenerator(handle_);
    }

    CurandGenerator(const CurandGenerator&) = delete;
    CurandGenerator& operator=(const CurandGenerator&) = delete;

    curandGenerator_t get() const { return handle_; }

   private:
    curandGenerator_t handle_ = nullptr;
  };
};

}  // namespace operators
}  // namespace paddle

// Source lines 38-39.
// Registers GaussianRandomKernel<float> through REGISTER_OP_GPU_KERNEL under
// the operator name gaussian_random; no other element type is registered here.
namespace ops = paddle::operators;
REGISTER_OP_GPU_KERNEL(gaussian_random, ops::GaussianRandomKernel<float>);