diff --git a/paddle/fluid/distributed/table/common_dense_table.cc b/paddle/fluid/distributed/table/common_dense_table.cc index 45f8eed353dc73af4ccda7c511190b34a94d3b65..4063e4f501d011a498ecd8edf74ab719acb533b3 100644 --- a/paddle/fluid/distributed/table/common_dense_table.cc +++ b/paddle/fluid/distributed/table/common_dense_table.cc @@ -29,6 +29,8 @@ void CommonDenseTable::create_initializer(const std::string& attr, initializers_[name] = new FillConstantInitializer(slices); } else if (slices[0] == "uniform_random") { initializers_[name] = new UniformInitializer(slices); + } else if (slices[0] == "truncated_gaussian_random") { + initializers_[name] = new TruncatedGaussianInitializer(slices); } else { PADDLE_THROW( platform::errors::InvalidArgument("%s can not be supported", name)); diff --git a/paddle/fluid/distributed/table/depends/initializers.h b/paddle/fluid/distributed/table/depends/initializers.h index e8857ed51560de1af34f55a4feca29d4e8b1292b..f46e659a88babb07918d02f1e05859829895f2bf 100644 --- a/paddle/fluid/distributed/table/depends/initializers.h +++ b/paddle/fluid/distributed/table/depends/initializers.h @@ -16,6 +16,7 @@ #include #include +#include #include #include #include @@ -23,6 +24,8 @@ #include "paddle/fluid/framework/generator.h" +#include "paddle/fluid/operators/truncated_gaussian_random_op.h" + namespace paddle { namespace distributed { @@ -108,6 +111,40 @@ class GaussianInitializer : public Initializer { std::normal_distribution dist_; }; +class TruncatedGaussianInitializer : public Initializer { + public: + explicit TruncatedGaussianInitializer(const std::vector &attrs) { + name_ = attrs[0]; + seed_ = static_cast(std::stoi(attrs[1])); + mean_ = std::stof(attrs[2]); + std_ = std::stof(attrs[3]); + + std::uniform_real_distribution dist_( + std::numeric_limits::min(), 1.0); + random_engine_ = framework::GetCPURandomEngine(seed_); + } + + float GetValue() override { + paddle::operators::TruncatedNormal truncated_normal(mean_, std_); + float value = truncated_normal(dist_(*random_engine_)); + return value; + } + + void GetValue(float *value, int numel) { + paddle::operators::TruncatedNormal truncated_normal(mean_, std_); + for (int x = 0; x < numel; ++x) { + value[x] = truncated_normal(dist_(*random_engine_)); + } + } + + private: + float std_; + float mean_; + + std::shared_ptr random_engine_; + std::uniform_real_distribution dist_; +}; + class FillConstantInitializer : public Initializer { public: explicit FillConstantInitializer(const std::vector &attrs) { diff --git a/paddle/fluid/distributed/table/depends/large_scale_kv.h b/paddle/fluid/distributed/table/depends/large_scale_kv.h index 9ab3711fe2ea0d053167582af838bdc2ba5fd5e1..55f8489b08cba04a132bba81c72ac34cf28a8ce2 100644 --- a/paddle/fluid/distributed/table/depends/large_scale_kv.h +++ b/paddle/fluid/distributed/table/depends/large_scale_kv.h @@ -134,6 +134,9 @@ class ValueBlock { } else if (slices[0] == "uniform_random") { initializers_.emplace_back( std::make_shared(slices)); + } else if (slices[0] == "truncated_gaussian_random") { + initializers_.emplace_back( + std::make_shared(slices)); } else { PADDLE_THROW(platform::errors::InvalidArgument( "%s can not be supported", attr));