未验证 提交 c98f144f 编写于 作者: C Chengmo 提交者: GitHub

add truncated gaussian random (#30922)

add truncated gaussian random
上级 4a8b8b45
......@@ -29,6 +29,8 @@ void CommonDenseTable::create_initializer(const std::string& attr,
initializers_[name] = new FillConstantInitializer(slices);
} else if (slices[0] == "uniform_random") {
initializers_[name] = new UniformInitializer(slices);
} else if (slices[0] == "truncated_gaussian_random") {
initializers_[name] = new TruncatedGaussianInitializer(slices);
} else {
PADDLE_THROW(
platform::errors::InvalidArgument("%s can not be supported", name));
......
......@@ -16,6 +16,7 @@
#include <functional>
#include <memory>
#include <random>
#include <string>
#include <utility>
#include <vector>
......@@ -23,6 +24,8 @@
#include "paddle/fluid/framework/generator.h"
#include "paddle/fluid/operators/truncated_gaussian_random_op.h"
namespace paddle {
namespace distributed {
......@@ -108,6 +111,40 @@ class GaussianInitializer : public Initializer {
std::normal_distribution<float> dist_;
};
class TruncatedGaussianInitializer : public Initializer {
public:
explicit TruncatedGaussianInitializer(const std::vector<std::string> &attrs) {
name_ = attrs[0];
seed_ = static_cast<unsigned int>(std::stoi(attrs[1]));
mean_ = std::stof(attrs[2]);
std_ = std::stof(attrs[3]);
std::uniform_real_distribution<float> dist_(
std::numeric_limits<float>::min(), 1.0);
random_engine_ = framework::GetCPURandomEngine(seed_);
}
float GetValue() override {
paddle::operators::TruncatedNormal<float> truncated_normal(mean_, std_);
float value = truncated_normal(dist_(*random_engine_));
return value;
}
void GetValue(float *value, int numel) {
paddle::operators::TruncatedNormal<float> truncated_normal(mean_, std_);
for (int x = 0; x < numel; ++x) {
value[x] = truncated_normal(dist_(*random_engine_));
}
}
private:
float std_;
float mean_;
std::shared_ptr<std::mt19937_64> random_engine_;
std::uniform_real_distribution<float> dist_;
};
class FillConstantInitializer : public Initializer {
public:
explicit FillConstantInitializer(const std::vector<std::string> &attrs) {
......
......@@ -134,6 +134,9 @@ class ValueBlock {
} else if (slices[0] == "uniform_random") {
initializers_.emplace_back(
std::make_shared<UniformInitializer>(slices));
} else if (slices[0] == "truncated_gaussian_random") {
initializers_.emplace_back(
std::make_shared<TruncatedGaussianInitializer>(slices));
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"%s can not be supported", attr));
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册