From 4973e07be3fab37b7559b9a8abce12260a3233ea Mon Sep 17 00:00:00 2001 From: tangwei12 Date: Mon, 6 Aug 2018 14:01:21 +0800 Subject: [PATCH] sampling op optimize --- paddle/fluid/operators/sampling_id_op.cc | 14 +++++---- paddle/fluid/operators/sampling_id_op.cu | 14 ++++----- paddle/fluid/operators/sampling_id_op.h | 36 +++++++++++++----------- 3 files changed, 34 insertions(+), 30 deletions(-) diff --git a/paddle/fluid/operators/sampling_id_op.cc b/paddle/fluid/operators/sampling_id_op.cc index b9e3b0372d..9729537d1e 100644 --- a/paddle/fluid/operators/sampling_id_op.cc +++ b/paddle/fluid/operators/sampling_id_op.cc @@ -57,9 +57,11 @@ SamplingId Operator. } // namespace paddle namespace ops = paddle::operators; -REGISTER_OP_CUDA_KERNEL( - sampling_id, - ops::SamplingIdKernel, - ops::SamplingIdKernel, - ops::SamplingIdKernel, - ops::SamplingIdKernel); +REGISTER_OPERATOR(sampling_id, ops::SamplingIdOp, ops::SamplingIdOpMaker, + paddle::framework::EmptyGradOpMaker); + +REGISTER_OP_CPU_KERNEL( + sampling_id, ops::SamplingIdKernel, + ops::SamplingIdKernel, + ops::SamplingIdKernel, + ops::SamplingIdKernel); diff --git a/paddle/fluid/operators/sampling_id_op.cu b/paddle/fluid/operators/sampling_id_op.cu index f82ba68ce4..e467165b6d 100644 --- a/paddle/fluid/operators/sampling_id_op.cu +++ b/paddle/fluid/operators/sampling_id_op.cu @@ -30,11 +30,9 @@ class SamplingIdOp : public framework::OperatorWithKernel { } // namespace paddle namespace ops = paddle::operators; -REGISTER_OPERATOR(sampling_id, ops::SamplingIdOp, ops::SamplingIdOpMaker, - paddle::framework::EmptyGradOpMaker); - -REGISTER_OP_CPU_KERNEL( - sampling_id, ops::SamplingIdKernel, - ops::SamplingIdKernel, - ops::SamplingIdKernel, - ops::SamplingIdKernel); +REGISTER_OP_CUDA_KERNEL( + sampling_id, + ops::SamplingIdKernel, + ops::SamplingIdKernel, + ops::SamplingIdKernel, + ops::SamplingIdKernel); diff --git a/paddle/fluid/operators/sampling_id_op.h b/paddle/fluid/operators/sampling_id_op.h index eeb72d8f7d..5bb1991fc5 100644 --- a/paddle/fluid/operators/sampling_id_op.h +++ b/paddle/fluid/operators/sampling_id_op.h @@ -15,30 +15,31 @@ limitations under the License. */ #include #include +#include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/op_registry.h" namespace paddle { namespace operators { +using Tensor = framework::Tensor; + template class SamplingIdKernel : public framework::OpKernel { - /// Produces random floating-point values, uniformly distributed on [0, 1). - std::uniform_real_distribution rand1_; - public: void Compute(const framework::ExecutionContext& context) const override { const Tensor* input = context.Input("X"); const int batch_size = static_cast(input->dims()[0]); const int width = static_cast(input->dims()[1]); - std::vector ids(batchSize); - auto& reng = get(); + std::vector ins_vector; + framework::TensorToVector(*input, context.device_context(), &ins_vector); - for (size_t i = 0; i < batchSize; ++i) { - double r = rand1_(reng); - int id = dim - 1; - for (int j = 0; j < dim; ++j) { - if ((r -= buf[i * dim + j]) < 0) { + std::vector ids(batch_size); + for (size_t i = 0; i < batch_size; ++i) { + double r = this->get_rand(); + int id = width - 1; + for (int j = 0; j < width; ++j) { + if ((r -= ins_vector[i * width + j]) < 0) { id = j; break; } @@ -50,19 +51,22 @@ class SamplingIdKernel : public framework::OpKernel { out_dim.push_back(static_cast(batch_size)); Tensor* output = context.Output("Output"); - output->Resize(framework::make_ddim(in_dim)); + output->Resize(framework::make_ddim(out_dim)); output->mutable_data(context.GetPlace()); framework::TensorFromVector(ids, context.device_context(), output); } - std::default_random_engine& get() { - auto engine = new std::default_random_engine; - engine->seed(defaultSeed); - return *engine; + double get_rand() const { + // Will be used to obtain a seed for the random number engine + std::random_device rd; + // Standard mersenne_twister_engine seeded with rd() + std::mt19937 gen(rd()); + std::uniform_real_distribution<> dis(0, 1); + return dis(gen); } private: unsigned int defaultSeed = 0; -} +}; } // namespace operators } // namespace paddle -- GitLab