diff --git a/paddle/operators/math/sampler.cc b/paddle/operators/math/sampler.cc index 52628c3b03b6b4d05d7d0c1c1529ecb26f5f8776..4f1cbfe31ac68499a51eda600b38b879f7ca055f 100644 --- a/paddle/operators/math/sampler.cc +++ b/paddle/operators/math/sampler.cc @@ -1,3 +1,17 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + #include "sampler.h" namespace paddle { @@ -7,8 +21,13 @@ Sampler::~Sampler() {} UniformSampler::UniformSampler(int64 range) : Sampler(range), inv_range_(1.0 / range) { - std::random_device r; - random_engine_ = std::make_shared(r()); + random_engine_ = std::make_shared(seed_); + dist_ = std::make_shared>(0, range); +} + +UniformSampler::UniformSampler(int64 range, unsigned int seed) + : Sampler(range, seed), inv_range_(1.0 / range) { + random_engine_ = std::make_shared(seed_); dist_ = std::make_shared>(0, range); } @@ -18,11 +37,15 @@ float UniformSampler::Probability(int64 value) const { return inv_range_; } LogUniformSampler::LogUniformSampler(int64 range) : Sampler(range), log_range_(log(range + 1)) { - std::random_device r; - random_engine_ = std::make_shared(r()); + random_engine_ = std::make_shared(seed_); dist_ = std::make_shared>(0, 1); } +LogUniformSampler::LogUniformSampler(int64 range, unsigned int seed) + : Sampler(range, seed), log_range_(log(range + 1)) { + random_engine_ = std::make_shared(seed_); + dist_ = std::make_shared>(0, 1); +} int64 LogUniformSampler::Sample() const { // Got Log Uniform distribution from uniform distribution by // inverse_transform_sampling method diff --git a/paddle/operators/math/sampler.h b/paddle/operators/math/sampler.h index bcd7bead35284833b9833f617442dd9fad74ee68..8f82089e7bd9e0ae6282459b650c225d6765faee 100644 --- a/paddle/operators/math/sampler.h +++ b/paddle/operators/math/sampler.h @@ -20,14 +20,21 @@ namespace paddle { namespace operators { namespace math { -// TODO: Support for GPU +// TODO(wanghaoshuang): Support for GPU /** * Sample integers from [0, range). */ class Sampler { public: - explicit Sampler(int64 range) : range_(range) { /* check range > 0*/ + explicit Sampler(int64 range) : range_(range) { + PADDLE_ENFORCE_GT(range, 0); + std::random_device r; + seed_ = r(); + } + explicit Sampler(int64 range, unsigned int seed) + : range_(range), seed_(seed) { + PADDLE_ENFORCE_GT(range, 0); } virtual ~Sampler(); // Sample a single value @@ -39,6 +46,7 @@ class Sampler { protected: const int64 range_; + unsigned int seed_; }; /** @@ -50,6 +58,8 @@ class UniformSampler : public Sampler { public: explicit UniformSampler(int64 range); + explicit UniformSampler(int64 range, unsigned int seed); + ~UniformSampler() override {} int64 Sample() const override; @@ -71,6 +81,8 @@ class LogUniformSampler : public Sampler { public: explicit LogUniformSampler(int64 range); + explicit LogUniformSampler(int64 range, unsigned int seed); + ~LogUniformSampler() override {} int64 Sample() const override;