From 4661f5589dc95a3bd3736848b820990c4c6e32d3 Mon Sep 17 00:00:00 2001 From: tangwei12 Date: Wed, 15 Aug 2018 16:01:53 +0800 Subject: [PATCH] random optimize --- paddle/fluid/operators/sampling_id_op.cc | 44 +++++++++++++++--------- 1 file changed, 27 insertions(+), 17 deletions(-) diff --git a/paddle/fluid/operators/sampling_id_op.cc b/paddle/fluid/operators/sampling_id_op.cc index 4929a7edc2..f8f94553be 100644 --- a/paddle/fluid/operators/sampling_id_op.cc +++ b/paddle/fluid/operators/sampling_id_op.cc @@ -36,9 +36,19 @@ class SamplingIdKernel : public framework::OpKernel { std::vector ins_vector; framework::TensorToVector(*input, context.device_context(), &ins_vector); + unsigned int seed = static_cast(ctx.Attr("seed")); + std::minstd_rand engine; + if (seed == 0) { + seed = std::random_device()(); + } + engine.seed(seed); + std::uniform_real_distribution dist( + static_cast(ctx.Attr("min")), + static_cast(ctx.Attr("max"))); + std::vector ids(batch_size); for (size_t i = 0; i < batch_size; ++i) { - double r = getRandReal(); + double r = dist(engine); int idx = width - 1; for (int j = 0; j < width; ++j) { if ((r -= ins_vector[i * width + j]) < 0) { @@ -57,16 +67,6 @@ class SamplingIdKernel : public framework::OpKernel { output->mutable_data(context.GetPlace()); framework::TensorFromVector(ids, context.device_context(), output); } - - private: - double getRandReal() const { - std::random_device - rd; // Will be used to obtain a seed for the random number engine - std::mt19937 gen(rd()); // Standard mersenne_twister_engine seeded with - // rd() - std::uniform_real_distribution<> dis(1.0, 2.0); - return dis(gen); - } }; class SamplingIdOp : public framework::OperatorWithKernel { @@ -78,6 +78,9 @@ class SamplingIdOp : public framework::OperatorWithKernel { "Input(X) of SamplingIdOp should not be null."); PADDLE_ENFORCE(ctx->HasOutput("Out"), "Output(Out) of SamplingIdOp should not be null."); + PADDLE_ENFORCE( + ctx->Attrs().Get("min") < ctx->Attrs().Get("max"), + "min must less then max"); auto input_dims = ctx->GetInputDim("X"); PADDLE_ENFORCE(input_dims.size() == 2, @@ -99,7 +102,17 @@ class SamplingIdOpMaker : public framework::OpProtoAndCheckerMaker { AddComment(R"DOC( SamplingId Operator. A layer for sampling id from multinomial distribution from the - input layer. Sampling one id for one sample.)DOC"); + input. Sampling one id for one sample.)DOC"); + AddAttr("min", "Minimum value of random. [default 0.0].") + .SetDefault(0.0f); + AddAttr("max", "Maximun value of random. [default 1.0].") + .SetDefault(1.0f); + AddAttr("seed", + "Random seed used for the random number engine. " + "0 means use a seed generated by the system." + "Note that if seed is not 0, this operator will always " + "generate the same random numbers every time. [default 0].") + .SetDefault(0); } }; } // namespace operators @@ -109,8 +122,5 @@ namespace ops = paddle::operators; REGISTER_OPERATOR(sampling_id, ops::SamplingIdOp, ops::SamplingIdOpMaker, paddle::framework::EmptyGradOpMaker); -REGISTER_OP_CPU_KERNEL( - sampling_id, ops::SamplingIdKernel, - ops::SamplingIdKernel, - ops::SamplingIdKernel, - ops::SamplingIdKernel); +REGISTER_OP_CPU_KERNEL(sampling_id, paddle::operators::SamplingIdKernel, + paddle::operators::SamplingIdKernel); -- GitLab