From 0c37705ddc55fd391fca46bca162789ef6d7df22 Mon Sep 17 00:00:00 2001 From: Yu Yang Date: Mon, 7 Aug 2017 12:53:34 +0800 Subject: [PATCH] Use thrust to implement uniform_random --- paddle/operators/uniform_random_op.cc | 3 +- paddle/operators/uniform_random_op.cu | 53 +++++++++++++++++++++++++-- paddle/operators/uniform_random_op.h | 33 ++++++++++------- 3 files changed, 70 insertions(+), 19 deletions(-) diff --git a/paddle/operators/uniform_random_op.cc b/paddle/operators/uniform_random_op.cc index e3e1357818d..dec188f2a81 100644 --- a/paddle/operators/uniform_random_op.cc +++ b/paddle/operators/uniform_random_op.cc @@ -49,5 +49,4 @@ Used to initialize tensor with uniform random generator. } // namespace paddle REGISTER_OP(uniform_random, ops::RandomOp, ops::RandomOpMaker); -REGISTER_OP_CPU_KERNEL(uniform_random, - ops::UniformRandomKernel); +REGISTER_OP_CPU_KERNEL(uniform_random, ops::CPUUniformRandomKernel); diff --git a/paddle/operators/uniform_random_op.cu b/paddle/operators/uniform_random_op.cu index 54ceaa14be4..89a274ae26c 100644 --- a/paddle/operators/uniform_random_op.cu +++ b/paddle/operators/uniform_random_op.cu @@ -12,7 +12,54 @@ See the License for the specific language governing permissions and limitations under the License. */ -#include "paddle/operators/uniform_random_op.h" +#include +#include +#include +#include +#include "paddle/operators/type_alias.h" -REGISTER_OP_GPU_KERNEL(uniform_random, - ops::UniformRandomKernel); +namespace paddle { +namespace operators { + +template +struct UniformGenerator { + T min_, max_; + unsigned int seed_; + + __host__ __device__ UniformGenerator(T min, T max, int seed) + : min_(min), max_(max), seed_(seed) {} + + __host__ __device__ T operator()(const unsigned int n) const { + thrust::minstd_rand rng; + rng.seed(seed_); + thrust::uniform_real_distribution dist(min_, max_); + rng.discard(n); + return dist(rng); + } +}; + +template +class GPUUniformRandomKernel : public OpKernel { + public: + void Compute(const ExecutionContext& context) const override { + auto* tensor = context.Output(0); + T* data = tensor->mutable_data(context.GetPlace()); + unsigned int seed = + static_cast(context.op_.GetAttr("seed")); + if (seed == 0) { + seed = std::random_device()(); + } + T min = static_cast(context.op_.GetAttr("min")); + T max = static_cast(context.op_.GetAttr("max")); + thrust::counting_iterator index_sequence_begin(0); + ssize_t N = framework::product(tensor->dims()); + thrust::transform(index_sequence_begin, index_sequence_begin + N, + thrust::device_ptr(data), + UniformGenerator(min, max, seed)); + } +}; + +} // namespace operators +} // namespace paddle + +REGISTER_OP_GPU_KERNEL(uniform_random, ops::GPUUniformRandomKernel); diff --git a/paddle/operators/uniform_random_op.h b/paddle/operators/uniform_random_op.h index 66fceef9455..e86771b444a 100644 --- a/paddle/operators/uniform_random_op.h +++ b/paddle/operators/uniform_random_op.h @@ -13,25 +13,30 @@ limitations under the License. */ #pragma once +#include +#include #include "paddle/operators/type_alias.h" namespace paddle { namespace operators { -template -class UniformRandomKernel : public OpKernel { +template +class CPUUniformRandomKernel : public OpKernel { public: - void Compute(const ExecutionContext &context) const override { - auto tensor = context.Output(0); - tensor->mutable_data(context.GetPlace()); - - auto eigenTensor = EigenVector::Flatten(*tensor); - auto dev = context.GetEigenDevice(); - auto min = context.op_.GetAttr("min"); - auto max = context.op_.GetAttr("max"); - auto seed = static_cast(context.op_.GetAttr("seed")); - auto diff = max - min; - Eigen::internal::UniformRandomGenerator gen(seed); - eigenTensor.device(dev) = eigenTensor.random(gen) * diff + min; + void Compute(const ExecutionContext& context) const override { + auto* tensor = context.Output(0); + T* data = tensor->mutable_data(context.GetPlace()); + unsigned int seed = + static_cast(context.op_.GetAttr("seed")); + std::minstd_rand engine; + if (seed == 0) { + seed = std::random_device()(); + } + engine.seed(seed); + std::uniform_real_distribution dist(static_cast(context.op_.GetAttr("min")), + static_cast(context.op_.GetAttr("max"))); + for (ssize_t i = 0; i < framework::product(tensor->dims()); ++i) { + data[i] = dist(engine); + } } }; -- GitLab