// random_op.h
#pragma once
#include <random>
#include "glog/logging.h"
#include "paddle/framework/eigen.h"
#include "paddle/framework/operator.h"

namespace paddle {
namespace operators {
/// Gaussian sampling entry point, parameterised on the device `Place`.
///
/// No definition is visible in this header; the implementations are
/// presumably provided per Place elsewhere. TODO(review): confirm.
///
/// Presumably fills `output[0, size)` with samples from a normal distribution
/// with mean `mean` and standard deviation `std`, drawing randomness from `g`.
///
/// @param g      random engine. It is taken by value, so the callee works on a
///               copy and the caller's engine state is not advanced.
/// @param output destination buffer (assumed to hold at least `size` elements).
/// @param size   number of elements to produce.
/// @param mean   distribution mean.
/// @param std    distribution standard deviation.
/// @return       a bool; presumably true on success. Confirm against the
///               definitions.
template <class Place, class T, class Generator>
bool Gaussian(Generator g,
              T* output,
              const int size,
              const T& mean,
              const T& std);

/// Kernel for the random operator. It fills the operator's first output tensor
/// with values produced by `Gaussian`, driven by the op attributes "mean",
/// "std" and "seed".
///
/// NOTE(review): assumes "mean"/"std" are registered as float attributes and
/// "seed" as an int attribute. Confirm against the op's attribute checker.
template <typename Place, typename T>
class RandomOpKernel : public framework::OpKernel {
public:
  void Compute(const framework::KernelContext& context) const override {
    // Attributes are stored type-erased. Extract typed values rather than
    // keeping the raw attribute objects, which cannot be passed as `const T&`.
    const T mean = static_cast<T>(context.op_.GetAttr<float>("mean"));
    // Named `stddev` (not `std`) to avoid confusion with namespace std.
    const T stddev = static_cast<T>(context.op_.GetAttr<float>("std"));
    const int seed = context.op_.GetAttr<int>("seed");

    auto* output = context.Output(0)->GetMutable<framework::Tensor>();
    // mutable_data allocates the buffer for this place and returns a typed
    // pointer. Gaussian expects a `T*`, not the Tensor itself.
    T* output_data = output->mutable_data<T>(context.GetPlace());

    // Engine seeded from the "seed" attribute, so the same seed gives the
    // same engine state.
    std::mt19937 generator(static_cast<std::mt19937::result_type>(seed));

    const bool ok = Gaussian<Place, T, std::mt19937>(
        generator, output_data, static_cast<int>(output->size()), mean, stddev);
    // NOTE(review): the bool result is assumed to signal success. Confirm
    // against the Gaussian definitions.
    CHECK(ok) << "RandomOpKernel: Gaussian sampling failed";
  }
};

// using paddle::platform::CPUPlace;
// template<CPUPlace, typename T>
// class RandomOpKernel : public framework::OpKernel {
// public:
//   void Compute(const framework::KernelContext& context) const override {

//     std::unique_ptr<default_random_engine> generator(seed);
//     for(size_t i=0; i < output->size(); ++i) {
//       output[i] = distribution(generator());
//     }
//   }

// };

// using paddle::platform::GPUPlace;
// template<GPUPlace, typename T>
// class RandomOpKernel : public framework::OpKernel {
// public:
//   void Compute(const framework::KernelContext& context) const override {

//   }
// }

}  // namespace operators
}  // namespace paddle