From 0d554f1dea499e72ce0e0d6c240aac0add23cf49 Mon Sep 17 00:00:00 2001 From: dongzhihong Date: Mon, 24 Jul 2017 21:01:57 +0800 Subject: [PATCH] "add template fill function" --- paddle/operators/random_op.cc | 14 +++++++++++- paddle/operators/random_op.cu | 13 ++++++++++++ paddle/operators/random_op.h | 40 +++++++++++++++++++++++++++++------ 3 files changed, 60 insertions(+), 7 deletions(-) diff --git a/paddle/operators/random_op.cc b/paddle/operators/random_op.cc index c219a0b67..b85ff8422 100644 --- a/paddle/operators/random_op.cc +++ b/paddle/operators/random_op.cc @@ -3,6 +3,18 @@ namespace paddle { namespace operators { + +using paddle::platform::GPUPlace; +template +bool Gaussian( + Generator g, T* output, const int size, const T& mean, const T& std) { + std::normal_distribution distribution(mean, std); + for (int i = 0; i < size; ++i) { + output[i] = distribution(g()); + } + return true; +} + class RandomOp : public framework::OperatorWithKernel { protected: void InferShape( @@ -12,7 +24,7 @@ protected: PADDLE_ENFORCE(outputs.size() == 1, "Output size of RandomOp must be one."); PADDLE_ENFORCE(inputs[0] != nullptr && outputs[0] != nullptr, "Inputs/Outputs of RandomOp must all be set."); - outputs[0]->set_dims(inputs[0]->dims()); + outputs[0]->set_dims(context.op_.attrs_.at("shape")); } }; diff --git a/paddle/operators/random_op.cu b/paddle/operators/random_op.cu index 50985f669..ea1096aeb 100644 --- a/paddle/operators/random_op.cu +++ b/paddle/operators/random_op.cu @@ -1,6 +1,19 @@ #include "paddle/operators/random_op.h" #include "paddle/framework/op_registry.h" +namespace paddle { +namespace operators { + +using paddle::platform::GPUPlace; +template +bool Gaussian(Generator g, T* output, const int size, const T& mean, const T& std) { + return curandGenerateNormal(g, output, size, mean, std); +} + +} // operators +} // paddle + + typedef paddle::operators::RandomOpKernel RandomOpKernel_GPU_float; REGISTER_OP_GPU_KERNEL(random_op, RandomOpKernel_GPU_float); \ No newline at end of file diff --git a/paddle/operators/random_op.h b/paddle/operators/random_op.h index 123d9d6ff..1b5fb16de 100644 --- a/paddle/operators/random_op.h +++ b/paddle/operators/random_op.h @@ -6,24 +6,52 @@ namespace paddle { namespace operators { +template +bool Gaussian( + Generator g, T* output, const int size, const T& mean, const T& std); + template class RandomOpKernel : public framework::OpKernel { public: void Compute(const framework::KernelContext& context) const override { - auto* output = context.Output(0)->GetMutable(); - output->mutable_data(context.GetPlace()); - - auto shape = context.op_.attrs_.at("Shape"); auto mean = context.op_.attrs_.at("mean"); auto std = context.op_.attrs_.at("std"); auto seed = context.op_.attrs_.at("seed"); + auto* output = context.Output(0)->GetMutable(); + output->mutable_data(context.GetPlace()); + + Gaussian(, output, output->size(), mean, std) : // std::default_random_engine generator(seed); // std::normal_distribution distribution(mean, std); - framework::EigenMatrix::From(*output).device(*( - context.GetEigenDevice())) = framework::EigenMatrix::Random(); + // framework::EigenMatrix::From(*output).device(*( + // context.GetEigenDevice())) = + // framework::EigenMatrix::Random(); } }; +// using paddle::platform::CPUPlace; +// template +// class RandomOpKernel : public framework::OpKernel { +// public: +// void Compute(const framework::KernelContext& context) const override { + +// std::unique_ptr generator(seed); +// for(size_t i=0; i < output->size(); ++i) { +// output[i] = distribution(generator()); +// } +// } + +// }; + +// using paddle::platform::GPUPlace; +// template +// class RandomOpKernel : public framework::OpKernel { +// public: +// void Compute(const framework::KernelContext& context) const override { + +// } +// } + } // namespace operators } // namespace paddle -- GitLab