random_op.h 1.9 KB
Newer Older
D
dongzhihong 已提交
1 2 3 4 5 6 7 8
#pragma once
#include <random>
#include "glog/logging.h"
#include "paddle/framework/eigen.h"
#include "paddle/framework/operator.h"

namespace paddle {
namespace operators {
D
dongzhihong 已提交
9

10
template <typename T>
D
dongzhihong 已提交
11 12
bool Gaussian(platform::CPUDeviceContext* ctx,
              T* output,
13 14 15
              const int size,
              const T& mean,
              const T& std,
D
dongzhihong 已提交
16
              const T& seed) {
D
dongzhihong 已提交
17 18
  auto g = ctx->RandGenerator(seed);
  std::normal_distribution<T> distribution(mean, std);
D
dongzhihong 已提交
19 20 21 22 23 24 25
  for (int i = 0; i < size; ++i) {
    output[i] = distribution(g);
  }
  return true;
}

#ifndef PADDLE_ONLY_CPU
26
template <typename T>
D
dongzhihong 已提交
27 28
bool Gaussian(platform::CUDADeviceContext* ctx,
              T* output,
29 30 31
              const int size,
              const T& mean,
              const T& std,
D
dongzhihong 已提交
32
              const T& seed) {
D
dongzhihong 已提交
33
  auto g = ctx->RandGenerator(seed);
D
dongzhihong 已提交
34 35 36
  return curandGenerateNormal(g, output, size, mean, std);
}
#endif
37

D
dongzhihong 已提交
38 39 40 41
/// Operator kernel that fills its first output tensor with normally
/// distributed values, parameterised by the "mean", "std" and "seed"
/// attributes.
///
/// The target device is chosen at run time from `context.GetPlace()`; the
/// `Place` template parameter is not used by this implementation.
template <typename Place, typename T>
class RandomOpKernel : public framework::OpKernel {
public:
  void Compute(const framework::KernelContext& context) const override {
    auto mean = context.op_.GetAttr<T>("mean");
    auto stddev = context.op_.GetAttr<T>("std");
    // NOTE(review): the seed attribute is read as T (the element type, e.g.
    // float); confirm this matches the attribute's declared type.
    auto seed = context.op_.GetAttr<T>("seed");
    auto* output = context.Output(0)->GetMutable<framework::Tensor>();

    // Allocate the output on the kernel's place once, rather than repeating
    // the same call in each device branch.
    auto place = context.GetPlace();
    T* data = output->mutable_data<T>(place);
    auto size = framework::product(output->dims());

    // NOTE(review): the boolean result of Gaussian is not checked here.
    if (platform::is_cpu_place(place)) {
      auto* cpu_ctx =
          dynamic_cast<platform::CPUDeviceContext*>(context.device_context_);
      // Fail loudly instead of handing a null context to Gaussian.
      CHECK(cpu_ctx != nullptr)
          << "RandomOpKernel: CPU place without a CPUDeviceContext";
      Gaussian(cpu_ctx, data, size, mean, stddev, seed);
    } else {
#ifndef PADDLE_ONLY_CPU
      auto* cuda_ctx =
          dynamic_cast<platform::CUDADeviceContext*>(context.device_context_);
      CHECK(cuda_ctx != nullptr)
          << "RandomOpKernel: non-CPU place without a CUDADeviceContext";
      Gaussian(cuda_ctx, data, size, mean, stddev, seed);
#else
      // The CUDA overload of Gaussian is only declared when PADDLE_ONLY_CPU
      // is not defined, so a CPU-only build cannot serve this branch.
      LOG(FATAL) << "RandomOpKernel: non-CPU place requested in a "
                    "PADDLE_ONLY_CPU build";
#endif
    }
  }
};

}  // namespace operators
}  // namespace paddle