未验证 提交 0c333543 作者: Chang Xu 提交者: GitHub

Fix truncated norm operator (#40287)

上级 d7112180
...@@ -23,7 +23,6 @@ ...@@ -23,7 +23,6 @@
#include "gflags/gflags.h" #include "gflags/gflags.h"
#include "paddle/fluid/framework/generator.h" #include "paddle/fluid/framework/generator.h"
#include "paddle/fluid/operators/truncated_gaussian_random_op.h" #include "paddle/fluid/operators/truncated_gaussian_random_op.h"
namespace paddle { namespace paddle {
...@@ -118,9 +117,13 @@ class TruncatedGaussianInitializer : public Initializer { ...@@ -118,9 +117,13 @@ class TruncatedGaussianInitializer : public Initializer {
seed_ = static_cast<unsigned int>(std::stoi(attrs[1])); seed_ = static_cast<unsigned int>(std::stoi(attrs[1]));
mean_ = std::stof(attrs[2]); mean_ = std::stof(attrs[2]);
std_ = std::stof(attrs[3]); std_ = std::stof(attrs[3]);
auto normal_cdf = [](float x) {
std::uniform_real_distribution<float> dist_( return (1.0 + std::erf(x / std::sqrt(2.0))) / 2.0;
std::numeric_limits<float>::min(), 1.0); };
float a_normal_cdf = normal_cdf((-2.0 - mean_) / std_);
float b_normal_cdf = normal_cdf((2.0 - mean_) / std_);
std::uniform_real_distribution<float> dist_(2.0 * a_normal_cdf - 1.0,
2.0 * b_normal_cdf - 1.0);
random_engine_ = framework::GetCPURandomEngine(seed_); random_engine_ = framework::GetCPURandomEngine(seed_);
} }
......
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. /* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License"); Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License. you may not use this file except in compliance with the License.
You may obtain a copy of the License at You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0 http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS, distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
...@@ -140,19 +137,9 @@ T Erfinv(T x) { ...@@ -140,19 +137,9 @@ T Erfinv(T x) {
template <typename T> template <typename T>
struct TruncatedNormal { struct TruncatedNormal {
T mean, std; T mean, std;
T a_normal_cdf; TruncatedNormal(T mean, T std) : mean(mean), std(std) {}
T b_normal_cdf;
TruncatedNormal(T mean, T std) : mean(mean), std(std) {
auto normal_cdf = [](T x) {
return (1.0 + std::erf(x / std::sqrt(2.0))) / 2.0;
};
a_normal_cdf = normal_cdf(-2.0);
b_normal_cdf = normal_cdf(2.0);
}
T operator()(T value) const { T operator()(T value) const {
auto p = a_normal_cdf + (b_normal_cdf - a_normal_cdf) * value; return std::sqrt(2.0) * Erfinv(value) * std + mean;
return std::sqrt(2.0) * Erfinv(2 * p - 1) * std + mean;
} }
}; };
......
...@@ -84,8 +84,13 @@ class NPUTruncatedGaussianRandomKernel : public framework::OpKernel<T> { ...@@ -84,8 +84,13 @@ class NPUTruncatedGaussianRandomKernel : public framework::OpKernel<T> {
Tensor cpu_tensor(tensor->dtype()); Tensor cpu_tensor(tensor->dtype());
cpu_tensor.Resize(tensor->dims()); cpu_tensor.Resize(tensor->dims());
T* cpu_data = cpu_tensor.mutable_data<T>(platform::CPUPlace()); T* cpu_data = cpu_tensor.mutable_data<T>(platform::CPUPlace());
std::uniform_real_distribution<T> dist(std::numeric_limits<float>::min(), auto normal_cdf = [](float x) {
1.0); return (1.0 + std::erf(x / std::sqrt(2.0))) / 2.0;
};
float a_normal_cdf = normal_cdf((-2.0 - mean) / std);
float b_normal_cdf = normal_cdf((2.0 - mean) / std);
std::uniform_real_distribution<float> dist(2.0 * a_normal_cdf - 1.0,
2.0 * b_normal_cdf - 1.0);
TruncatedNormal<T> truncated_normal(mean, std); TruncatedNormal<T> truncated_normal(mean, std);
int64_t size = tensor->numel(); int64_t size = tensor->numel();
......
...@@ -32,8 +32,13 @@ class XPUTruncatedGaussianRandomKernel : public framework::OpKernel<T> { ...@@ -32,8 +32,13 @@ class XPUTruncatedGaussianRandomKernel : public framework::OpKernel<T> {
auto* tensor = context.Output<framework::Tensor>("Out"); auto* tensor = context.Output<framework::Tensor>("Out");
T* data = tensor->mutable_data<T>(context.GetPlace()); T* data = tensor->mutable_data<T>(context.GetPlace());
std::uniform_real_distribution<T> dist(std::numeric_limits<float>::min(), auto normal_cdf = [](float x) {
1.0); return (1.0 + std::erf(x / std::sqrt(2.0))) / 2.0;
};
float a_normal_cdf = normal_cdf((-2.0 - mean) / std);
float b_normal_cdf = normal_cdf((2.0 - mean) / std);
std::uniform_real_distribution<float> dist(2.0 * a_normal_cdf - 1.0,
2.0 * b_normal_cdf - 1.0);
TruncatedNormal<T> truncated_normal(mean, std); TruncatedNormal<T> truncated_normal(mean, std);
int64_t size = tensor->numel(); int64_t size = tensor->numel();
......
...@@ -37,8 +37,13 @@ void TruncatedGaussianRandomKernel(const Context& dev_ctx, ...@@ -37,8 +37,13 @@ void TruncatedGaussianRandomKernel(const Context& dev_ctx,
T* data = dev_ctx.template Alloc<T>(tensor); T* data = dev_ctx.template Alloc<T>(tensor);
std::uniform_real_distribution<T> dist(std::numeric_limits<float>::min(), auto normal_cdf = [](float x) {
1.0); return (1.0 + std::erf(x / std::sqrt(2.0))) / 2.0;
};
float a_normal_cdf = normal_cdf((-2.0 - mean) / std);
float b_normal_cdf = normal_cdf((2.0 - mean) / std);
std::uniform_real_distribution<float> dist(2.0 * a_normal_cdf - 1.0,
2.0 * b_normal_cdf - 1.0);
TruncatedNormal<T> truncated_normal(mean, std); TruncatedNormal<T> truncated_normal(mean, std);
int64_t size = tensor->numel(); int64_t size = tensor->numel();
......
...@@ -33,23 +33,27 @@ struct GPUTruncatedNormal { ...@@ -33,23 +33,27 @@ struct GPUTruncatedNormal {
T mean, std; T mean, std;
T a_normal_cdf; T a_normal_cdf;
T b_normal_cdf; T b_normal_cdf;
unsigned int seed; unsigned int seed;
T numeric_min; T numeric_min;
__host__ __device__ GPUTruncatedNormal(T mean, T std, T numeric_min, int seed) __host__ __device__ GPUTruncatedNormal(T mean, T std, T numeric_min, int seed)
: mean(mean), std(std), seed(seed), numeric_min(numeric_min) { : mean(mean), std(std), seed(seed), numeric_min(numeric_min) {
a_normal_cdf = (1.0 + erff(-2.0 / sqrtf(2.0))) / 2.0; auto normal_cdf = [](float x) {
b_normal_cdf = (1.0 + erff(2.0 / sqrtf(2.0))) / 2.0; return (1.0 + std::erf(x / std::sqrt(2.0))) / 2.0;
};
a_normal_cdf = normal_cdf((-2.0 - mean) / std);
b_normal_cdf = normal_cdf((2.0 - mean) / std);
} }
__host__ __device__ T operator()(const unsigned int n) const { __host__ __device__ T operator()(const unsigned int n) const {
thrust::minstd_rand rng; thrust::minstd_rand rng;
rng.seed(seed); rng.seed(seed);
thrust::uniform_real_distribution<T> dist(numeric_min, 1); thrust::uniform_real_distribution<T> dist(2.0 * a_normal_cdf - 1.0,
2.0 * b_normal_cdf - 1.0);
rng.discard(n); rng.discard(n);
T value = dist(rng); T value = dist(rng);
auto p = a_normal_cdf + (b_normal_cdf - a_normal_cdf) * value; return std::sqrt(2.0) * erfinvf(value) * std + mean;
return std::sqrt(2.0) * erfinvf(2 * p - 1) * std + mean;
} }
}; };
...@@ -69,18 +73,21 @@ struct TruncatedNormalOffset { ...@@ -69,18 +73,21 @@ struct TruncatedNormalOffset {
seed(seed), seed(seed),
numeric_min(numeric_min), numeric_min(numeric_min),
offset_(offset) { offset_(offset) {
a_normal_cdf = (1.0 + erff(-2.0 / sqrtf(2.0))) / 2.0; auto normal_cdf = [](float x) {
b_normal_cdf = (1.0 + erff(2.0 / sqrtf(2.0))) / 2.0; return (1.0 + std::erf(x / std::sqrt(2.0))) / 2.0;
};
a_normal_cdf = normal_cdf((-2.0 - mean) / std);
b_normal_cdf = normal_cdf((2.0 - mean) / std);
} }
__host__ __device__ T operator()(const unsigned int n) const { __host__ __device__ T operator()(const unsigned int n) const {
thrust::minstd_rand rng; thrust::minstd_rand rng;
rng.seed(seed); rng.seed(seed);
thrust::uniform_real_distribution<T> dist(numeric_min, 1); thrust::uniform_real_distribution<T> dist(2.0 * a_normal_cdf - 1.0,
2.0 * b_normal_cdf - 1.0);
rng.discard(n + offset_); rng.discard(n + offset_);
T value = dist(rng); T value = dist(rng);
auto p = a_normal_cdf + (b_normal_cdf - a_normal_cdf) * value; return std::sqrt(2.0) * erfinvf(value) * std + mean;
return std::sqrt(2.0) * erfinvf(2 * p - 1) * std + mean;
} }
}; };
......
...@@ -141,19 +141,9 @@ T Erfinv(T x) { ...@@ -141,19 +141,9 @@ T Erfinv(T x) {
template <typename T> template <typename T>
struct TruncatedNormal { struct TruncatedNormal {
T mean, std; T mean, std;
T a_normal_cdf; TruncatedNormal(T mean, T std) : mean(mean), std(std) {}
T b_normal_cdf;
TruncatedNormal(T mean, T std) : mean(mean), std(std) {
auto normal_cdf = [](T x) {
return (1.0 + std::erf(x / std::sqrt(2.0))) / 2.0;
};
a_normal_cdf = normal_cdf(-2.0);
b_normal_cdf = normal_cdf(2.0);
}
T operator()(T value) const { T operator()(T value) const {
auto p = a_normal_cdf + (b_normal_cdf - a_normal_cdf) * value; return std::sqrt(2.0) * Erfinv(value) * std + mean;
return std::sqrt(2.0) * Erfinv(2 * p - 1) * std + mean;
} }
}; };
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册