diff --git a/paddle/operators/math/sampler.cc b/paddle/operators/math/sampler.cc
new file mode 100644
index 0000000000000000000000000000000000000000..52628c3b03b6b4d05d7d0c1c1529ecb26f5f8776
--- /dev/null
+++ b/paddle/operators/math/sampler.cc
@@ -0,0 +1,58 @@
+#include "sampler.h"
+
+#include <cmath>
+
+namespace paddle {
+namespace operators {
+namespace math {
+
+Sampler::~Sampler() {}
+
+// Seeds a 64-bit Mersenne Twister from std::random_device and prepares a
+// uniform integer distribution over [0, range).
+UniformSampler::UniformSampler(int64 range)
+    : Sampler(range), inv_range_(1.0 / range) {
+  std::random_device r;
+  random_engine_ = std::make_shared<std::mt19937_64>(r());
+  // std::uniform_int_distribution draws from the closed interval [a, b], so
+  // the upper bound must be range - 1 to keep samples inside [0, range).
+  dist_ = std::make_shared<std::uniform_int_distribution<int64>>(0, range - 1);
+}
+
+int64 UniformSampler::Sample() const { return (*dist_)(*random_engine_); }
+
+// Every value is equally likely, so the argument is not consulted.
+float UniformSampler::Probability(int64 value) const { return inv_range_; }
+
+// Seeds a 64-bit Mersenne Twister from std::random_device and prepares a
+// uniform real distribution over [0, 1) that Sample() transforms.
+LogUniformSampler::LogUniformSampler(int64 range)
+    : Sampler(range), log_range_(std::log(range + 1)) {
+  std::random_device r;
+  random_engine_ = std::make_shared<std::mt19937_64>(r());
+  dist_ = std::make_shared<std::uniform_real_distribution<>>(0, 1);
+}
+
+int64 LogUniformSampler::Sample() const {
+  // Got Log Uniform distribution from uniform distribution by
+  // inverse_transform_sampling method
+  // More details:
+  // https://wanghaoshuang.github.io/2017/11/Log-uniform-distribution-sampler/
+  const int64 value =
+      static_cast<int64>(std::exp((*dist_)(*random_engine_) * log_range_)) - 1;
+  // Mathematically, value should be < range_, but might not be due to some
+  // floating point roundoff, so we mod by range_.
+  return value % range_;
+}
+
+float LogUniformSampler::Probability(int64 value) const {
+  // Given f(x) = 1/[(x+1) * log_range_]
+  // The value's probability is integral of f(x) from value to (value + 1)
+  // More details:
+  // https://wanghaoshuang.github.io/2017/11/Log-uniform-distribution-sampler
+  return (std::log((value + 2.0) / (value + 1.0))) / log_range_;
+}
+
+}  // namespace math
+}  // namespace operators
+}  // namespace paddle
diff --git a/paddle/operators/math/sampler.h b/paddle/operators/math/sampler.h
new file mode 100644
index 0000000000000000000000000000000000000000..bcd7bead35284833b9833f617442dd9fad74ee68
--- /dev/null
+++ b/paddle/operators/math/sampler.h
@@ -0,0 +1,93 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+#include <memory>
+#include <random>
+// NOTE(review): `long` is only 32 bits wide on LLP64 targets (e.g. 64-bit
+// Windows); consider std::int64_t if such platforms must be supported.
+typedef long int64;
+namespace paddle {
+namespace operators {
+namespace math {
+
+// TODO: Support for GPU
+
+/**
+ * Sample integers from [0, range).
+ */
+class Sampler {
+ public:
+  // TODO: check range > 0; the derived samplers assume a non-empty interval.
+  explicit Sampler(int64 range) : range_(range) {}
+  virtual ~Sampler();
+  // Sample a single value
+  virtual int64 Sample() const = 0;
+  // The probability that a single call to Sample() returns the given value.
+  virtual float Probability(int64 value) const = 0;
+
+  // Exclusive upper bound of the sampled interval.
+  int64 range() const { return range_; }
+
+ protected:
+  const int64 range_;
+};
+
+/**
+ * Sample integers from [0, range).
+ * And the distribution function is:
+ * P(x) = 1 / range
+ */
+class UniformSampler : public Sampler {
+ public:
+  explicit UniformSampler(int64 range);
+
+  ~UniformSampler() override {}
+
+  int64 Sample() const override;
+
+  float Probability(int64 value) const override;
+
+ private:
+  const float inv_range_;
+  // Held through pointers, which lets the const Sample() advance the engine.
+  std::shared_ptr<std::mt19937_64> random_engine_;
+  std::shared_ptr<std::uniform_int_distribution<int64>> dist_;
+};
+
+/**
+ * Sample integers from [0, range).
+ * And the distribution function is:
+ * P(x) = (1/ln(range+1)) * ln(1 + 1/(x + 1))
+ */
+class LogUniformSampler : public Sampler {
+ public:
+  explicit LogUniformSampler(int64 range);
+
+  ~LogUniformSampler() override {}
+
+  int64 Sample() const override;
+
+  float Probability(int64 value) const override;
+
+ private:
+  const float log_range_;
+  // Held through pointers, which lets the const Sample() advance the engine.
+  std::shared_ptr<std::mt19937_64> random_engine_;
+  std::shared_ptr<std::uniform_real_distribution<>> dist_;
+};
+
+}  // namespace math
+}  // namespace operators
+}  // namespace paddle