未验证 提交 0c333543 编写于 作者: C Chang Xu 提交者: GitHub

Fix truncated norm operator (#40287)

上级 d7112180
......@@ -23,7 +23,6 @@
#include "gflags/gflags.h"
#include "paddle/fluid/framework/generator.h"
#include "paddle/fluid/operators/truncated_gaussian_random_op.h"
namespace paddle {
......@@ -118,9 +117,13 @@ class TruncatedGaussianInitializer : public Initializer {
seed_ = static_cast<unsigned int>(std::stoi(attrs[1]));
mean_ = std::stof(attrs[2]);
std_ = std::stof(attrs[3]);
std::uniform_real_distribution<float> dist_(
std::numeric_limits<float>::min(), 1.0);
auto normal_cdf = [](float x) {
return (1.0 + std::erf(x / std::sqrt(2.0))) / 2.0;
};
float a_normal_cdf = normal_cdf((-2.0 - mean_) / std_);
float b_normal_cdf = normal_cdf((2.0 - mean_) / std_);
std::uniform_real_distribution<float> dist_(2.0 * a_normal_cdf - 1.0,
2.0 * b_normal_cdf - 1.0);
random_engine_ = framework::GetCPURandomEngine(seed_);
}
......
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
......@@ -140,19 +137,9 @@ T Erfinv(T x) {
// Functor that maps an input sample to a normally distributed value by
// applying Erfinv, scaling by `std` and shifting by `mean`.
//
// NOTE(review): this span looks like a diff whose +/- markers were lost. It
// contains two constructors with the same signature and two consecutive
// return statements, so it cannot compile as written. Presumably the variant
// that stores a_normal_cdf / b_normal_cdf is the removed side and the
// one-line constructor plus `Erfinv(value)` is the added side -- TODO confirm
// against the upstream commit.
template <typename T>
struct TruncatedNormal {
T mean, std;
// Normal CDF values at -2.0 and 2.0; assigned only by the first constructor.
T a_normal_cdf;
T b_normal_cdf;
// Variant 1: stores mean/std and precomputes the CDF at the fixed bounds
// -2.0 and 2.0 (independent of mean and std).
TruncatedNormal(T mean, T std) : mean(mean), std(std) {
// Standard normal CDF expressed through erf: (1 + erf(x / sqrt(2))) / 2.
auto normal_cdf = [](T x) {
return (1.0 + std::erf(x / std::sqrt(2.0))) / 2.0;
};
a_normal_cdf = normal_cdf(-2.0);
b_normal_cdf = normal_cdf(2.0);
}
// Variant 2: stores mean/std only and leaves the CDF members untouched.
TruncatedNormal(T mean, T std) : mean(mean), std(std) {}
// Transforms `value` into a sample; the two return statements below are the
// two variants of the transform.
T operator()(T value) const {
// Linear interpolation between the two CDF bounds: for value in [0, 1],
// p lies between a_normal_cdf and b_normal_cdf.
auto p = a_normal_cdf + (b_normal_cdf - a_normal_cdf) * value;
// Variant 1: rescales p to 2p - 1 before inverting erf.
return std::sqrt(2.0) * Erfinv(2 * p - 1) * std + mean;
// Variant 2: passes `value` to Erfinv directly (unreachable as written,
// because it follows the return above).
return std::sqrt(2.0) * Erfinv(value) * std + mean;
}
};
......
......@@ -84,8 +84,13 @@ class NPUTruncatedGaussianRandomKernel : public framework::OpKernel<T> {
Tensor cpu_tensor(tensor->dtype());
cpu_tensor.Resize(tensor->dims());
T* cpu_data = cpu_tensor.mutable_data<T>(platform::CPUPlace());
std::uniform_real_distribution<T> dist(std::numeric_limits<float>::min(),
1.0);
auto normal_cdf = [](float x) {
return (1.0 + std::erf(x / std::sqrt(2.0))) / 2.0;
};
float a_normal_cdf = normal_cdf((-2.0 - mean) / std);
float b_normal_cdf = normal_cdf((2.0 - mean) / std);
std::uniform_real_distribution<float> dist(2.0 * a_normal_cdf - 1.0,
2.0 * b_normal_cdf - 1.0);
TruncatedNormal<T> truncated_normal(mean, std);
int64_t size = tensor->numel();
......
......@@ -32,8 +32,13 @@ class XPUTruncatedGaussianRandomKernel : public framework::OpKernel<T> {
auto* tensor = context.Output<framework::Tensor>("Out");
T* data = tensor->mutable_data<T>(context.GetPlace());
std::uniform_real_distribution<T> dist(std::numeric_limits<float>::min(),
1.0);
auto normal_cdf = [](float x) {
return (1.0 + std::erf(x / std::sqrt(2.0))) / 2.0;
};
float a_normal_cdf = normal_cdf((-2.0 - mean) / std);
float b_normal_cdf = normal_cdf((2.0 - mean) / std);
std::uniform_real_distribution<float> dist(2.0 * a_normal_cdf - 1.0,
2.0 * b_normal_cdf - 1.0);
TruncatedNormal<T> truncated_normal(mean, std);
int64_t size = tensor->numel();
......
......@@ -37,8 +37,13 @@ void TruncatedGaussianRandomKernel(const Context& dev_ctx,
T* data = dev_ctx.template Alloc<T>(tensor);
std::uniform_real_distribution<T> dist(std::numeric_limits<float>::min(),
1.0);
auto normal_cdf = [](float x) {
return (1.0 + std::erf(x / std::sqrt(2.0))) / 2.0;
};
float a_normal_cdf = normal_cdf((-2.0 - mean) / std);
float b_normal_cdf = normal_cdf((2.0 - mean) / std);
std::uniform_real_distribution<float> dist(2.0 * a_normal_cdf - 1.0,
2.0 * b_normal_cdf - 1.0);
TruncatedNormal<T> truncated_normal(mean, std);
int64_t size = tensor->numel();
......
......@@ -33,23 +33,27 @@ struct GPUTruncatedNormal {
T mean, std;
T a_normal_cdf;
T b_normal_cdf;
unsigned int seed;
T numeric_min;
// Constructor: stores mean, std, seed and numeric_min, then fills
// a_normal_cdf / b_normal_cdf.
//
// NOTE(review): both sides of a diff appear to be present here. The CDF
// members are assigned twice: first from the fixed bounds -2.0 and 2.0, then
// from bounds shifted by `mean` and divided by `std`. As written the second
// pair of assignments wins. Presumably the first pair is the removed side --
// TODO confirm against the upstream commit.
__host__ __device__ GPUTruncatedNormal(T mean, T std, T numeric_min, int seed)
: mean(mean), std(std), seed(seed), numeric_min(numeric_min) {
// Normal CDF at -2.0 and 2.0 via erff; does not depend on mean or std.
a_normal_cdf = (1.0 + erff(-2.0 / sqrtf(2.0))) / 2.0;
b_normal_cdf = (1.0 + erff(2.0 / sqrtf(2.0))) / 2.0;
// Standard normal CDF expressed through erf: (1 + erf(x / sqrt(2))) / 2.
auto normal_cdf = [](float x) {
return (1.0 + std::erf(x / std::sqrt(2.0))) / 2.0;
};
// CDF at (-2 - mean) / std and (2 - mean) / std, i.e. the bounds -2 and 2
// expressed in standardized units of this mean/std.
a_normal_cdf = normal_cdf((-2.0 - mean) / std);
b_normal_cdf = normal_cdf((2.0 - mean) / std);
}
// Produces the sample for index `n`: seeds a fresh engine with `seed`,
// discards `n` values, draws one uniform value and maps it through erfinvf,
// `std` and `mean`.
//
// NOTE(review): both sides of a diff appear to be present here. `dist` is
// declared twice and there are two consecutive returns, so this cannot
// compile as written. Presumably the (numeric_min, 1) range with the `p`
// interpolation is the removed side -- TODO confirm.
__host__ __device__ T operator()(const unsigned int n) const {
thrust::minstd_rand rng;
rng.seed(seed);
// Variant 1: uniform range from numeric_min to 1.
thrust::uniform_real_distribution<T> dist(numeric_min, 1);
// Variant 2: uniform range from 2*a_normal_cdf - 1 to 2*b_normal_cdf - 1,
// which is the erf-domain image of the two CDF bounds.
thrust::uniform_real_distribution<T> dist(2.0 * a_normal_cdf - 1.0,
2.0 * b_normal_cdf - 1.0);
// Skip n values so each index n reads a different position of the stream.
rng.discard(n);
T value = dist(rng);
// Variant 1: interpolate between the CDF bounds, then invert via 2p - 1.
auto p = a_normal_cdf + (b_normal_cdf - a_normal_cdf) * value;
return std::sqrt(2.0) * erfinvf(2 * p - 1) * std + mean;
// Variant 2: invert `value` directly (unreachable as written, because it
// follows the return above).
return std::sqrt(2.0) * erfinvf(value) * std + mean;
}
};
......@@ -69,18 +73,21 @@ struct TruncatedNormalOffset {
seed(seed),
numeric_min(numeric_min),
offset_(offset) {
a_normal_cdf = (1.0 + erff(-2.0 / sqrtf(2.0))) / 2.0;
b_normal_cdf = (1.0 + erff(2.0 / sqrtf(2.0))) / 2.0;
auto normal_cdf = [](float x) {
return (1.0 + std::erf(x / std::sqrt(2.0))) / 2.0;
};
a_normal_cdf = normal_cdf((-2.0 - mean) / std);
b_normal_cdf = normal_cdf((2.0 - mean) / std);
}
// Produces the sample for index `n`: seeds a fresh engine with `seed`,
// discards `n + offset_` values, draws one uniform value and maps it through
// erfinvf, `std` and `mean`.
//
// NOTE(review): both sides of a diff appear to be present here. `dist` is
// declared twice and there are two consecutive returns, so this cannot
// compile as written. Presumably the (numeric_min, 1) range with the `p`
// interpolation is the removed side -- TODO confirm.
__host__ __device__ T operator()(const unsigned int n) const {
thrust::minstd_rand rng;
rng.seed(seed);
// Variant 1: uniform range from numeric_min to 1.
thrust::uniform_real_distribution<T> dist(numeric_min, 1);
// Variant 2: uniform range from 2*a_normal_cdf - 1 to 2*b_normal_cdf - 1,
// which is the erf-domain image of the two CDF bounds.
thrust::uniform_real_distribution<T> dist(2.0 * a_normal_cdf - 1.0,
2.0 * b_normal_cdf - 1.0);
// Skip n + offset_ values, so the stream position depends on both the
// element index and the stored offset.
rng.discard(n + offset_);
T value = dist(rng);
// Variant 1: interpolate between the CDF bounds, then invert via 2p - 1.
auto p = a_normal_cdf + (b_normal_cdf - a_normal_cdf) * value;
return std::sqrt(2.0) * erfinvf(2 * p - 1) * std + mean;
// Variant 2: invert `value` directly (unreachable as written, because it
// follows the return above).
return std::sqrt(2.0) * erfinvf(value) * std + mean;
}
};
......
......@@ -141,19 +141,9 @@ T Erfinv(T x) {
// Functor that maps an input sample to a normally distributed value by
// applying Erfinv, scaling by `std` and shifting by `mean`.
//
// NOTE(review): this span looks like a diff whose +/- markers were lost. It
// contains two constructors with the same signature and two consecutive
// return statements, so it cannot compile as written. Presumably the variant
// that stores a_normal_cdf / b_normal_cdf is the removed side and the
// one-line constructor plus `Erfinv(value)` is the added side -- TODO confirm
// against the upstream commit.
template <typename T>
struct TruncatedNormal {
T mean, std;
// Normal CDF values at -2.0 and 2.0; assigned only by the first constructor.
T a_normal_cdf;
T b_normal_cdf;
// Variant 1: stores mean/std and precomputes the CDF at the fixed bounds
// -2.0 and 2.0 (independent of mean and std).
TruncatedNormal(T mean, T std) : mean(mean), std(std) {
// Standard normal CDF expressed through erf: (1 + erf(x / sqrt(2))) / 2.
auto normal_cdf = [](T x) {
return (1.0 + std::erf(x / std::sqrt(2.0))) / 2.0;
};
a_normal_cdf = normal_cdf(-2.0);
b_normal_cdf = normal_cdf(2.0);
}
// Variant 2: stores mean/std only and leaves the CDF members untouched.
TruncatedNormal(T mean, T std) : mean(mean), std(std) {}
// Transforms `value` into a sample; the two return statements below are the
// two variants of the transform.
T operator()(T value) const {
// Linear interpolation between the two CDF bounds: for value in [0, 1],
// p lies between a_normal_cdf and b_normal_cdf.
auto p = a_normal_cdf + (b_normal_cdf - a_normal_cdf) * value;
// Variant 1: rescales p to 2p - 1 before inverting erf.
return std::sqrt(2.0) * Erfinv(2 * p - 1) * std + mean;
// Variant 2: passes `value` to Erfinv directly (unreachable as written,
// because it follows the return above).
return std::sqrt(2.0) * Erfinv(value) * std + mean;
}
};
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册