// random_op.h — committed by dongzhihong
#pragma once
#include <random>
#include "glog/logging.h"
#include "paddle/framework/eigen.h"
#include "paddle/framework/operator.h"

namespace paddle {
namespace operators {
/// Declaration of the Gaussian fill routine used by RandomOpKernel; its
/// definitions are not part of this header.
///
/// @tparam T              element type of the tensor and of mean/std/seed.
/// @tparam DeviceContext  context type; deduced at the call site (the kernel
///                        passes its `device_context_`).
/// @param ctx     device context forwarded by the caller.
/// @param output  destination tensor (the kernel allocates it beforehand).
/// @param size    element count; the kernel passes product(output->dims()).
/// @param mean    distribution mean.
/// @param std     distribution standard deviation.
/// @param seed    random seed, typed as T (not an integer) in this signature.
/// @return a bool status. NOTE(review): presumably true on success — confirm
///         against the definitions.
template <typename T, typename DeviceContext>
bool Gaussian(DeviceContext& ctx, framework::Tensor* output, const int size,
              const T& mean, const T& std, const T& seed);
/// Kernel intended to fill its output tensor with Gaussian-distributed values
/// by delegating to Gaussian().
///
/// Operator attributes read (all as type T):
///   - "mean": distribution mean.
///   - "std":  distribution standard deviation.
///   - "seed": seed forwarded to Gaussian().
///             NOTE(review): a seed is naturally an integer; reading it as T
///             (e.g. float) can lose precision for large seeds. Kept as T to
///             match Gaussian()'s signature and, presumably, the attribute's
///             registration — confirm before changing.
///
/// @tparam Place  place this kernel is instantiated for; not referenced in
///                the body.
/// @tparam T      element type of the output tensor and of the attributes.
template <typename Place, typename T>
class RandomOpKernel : public framework::OpKernel {
public:
  /// Reads the distribution attributes, prepares output 0 for element type T
  /// on the current place, and hands it to Gaussian() for filling.
  void Compute(const framework::KernelContext& context) const override {
    const auto mean = context.op_.GetAttr<T>("mean");
    const auto std_dev = context.op_.GetAttr<T>("std");
    const auto seed = context.op_.GetAttr<T>("seed");

    auto* output = context.Output(0)->GetMutable<framework::Tensor>();
    // Presumably allocates storage for T on this place so Gaussian() can
    // write into it.
    output->mutable_data<T>(context.GetPlace());

    // Gaussian() returns a bool status that this kernel has never acted on;
    // the discard is now explicit. NOTE(review): consider enforcing it once
    // the meaning of `false` is confirmed in Gaussian()'s definitions.
    static_cast<void>(Gaussian(context.device_context_,
                               output,
                               framework::product(output->dims()),
                               mean,
                               std_dev,
                               seed));
  }
};

// using paddle::platform::CPUPlace;
// template<CPUPlace, typename T>
// class RandomOpKernel : public framework::OpKernel {
// public:
//   void Compute(const framework::KernelContext& context) const override {

//     std::unique_ptr<default_random_engine> generator(seed);
//     for(size_t i=0; i < output->size(); ++i) {
//       output[i] = distribution(generator());
//     }
//   }

// };

// using paddle::platform::GPUPlace;
// template<GPUPlace, typename T>
// class RandomOpKernel : public framework::OpKernel {
// public:
//   void Compute(const framework::KernelContext& context) const override {

//   }
// }

}  // namespace operators
}  // namespace paddle