未验证 提交 7e651c86 编写于 作者: W whs 提交者: GitHub

Fix truncated norm (#13785)

* Fix truncated normal.

* test=develop
上级 16b1beb2
...@@ -148,7 +148,7 @@ struct TruncatedNormal { ...@@ -148,7 +148,7 @@ struct TruncatedNormal {
T operator()(T value) const { T operator()(T value) const {
auto p = a_normal_cdf + (b_normal_cdf - a_normal_cdf) * value; auto p = a_normal_cdf + (b_normal_cdf - a_normal_cdf) * value;
return (std::sqrt(2.0) * Erfinv(2 * p - 1) + mean) * std; return std::sqrt(2.0) * Erfinv(2 * p - 1) * std + mean;
} }
}; };
......
...@@ -42,7 +42,7 @@ struct TruncatedNormal { ...@@ -42,7 +42,7 @@ struct TruncatedNormal {
rng.discard(n); rng.discard(n);
T value = dist(rng); T value = dist(rng);
auto p = a_normal_cdf + (b_normal_cdf - a_normal_cdf) * value; auto p = a_normal_cdf + (b_normal_cdf - a_normal_cdf) * value;
return (std::sqrt(2.0) * erfinvf(2 * p - 1) + mean) * std; return std::sqrt(2.0) * erfinvf(2 * p - 1) * std + mean;
} }
}; };
...@@ -52,6 +52,7 @@ class GPUTruncatedGaussianRandomKernel : public framework::OpKernel<T> { ...@@ -52,6 +52,7 @@ class GPUTruncatedGaussianRandomKernel : public framework::OpKernel<T> {
void Compute(const framework::ExecutionContext& context) const override { void Compute(const framework::ExecutionContext& context) const override {
auto* tensor = context.Output<framework::Tensor>("Out"); auto* tensor = context.Output<framework::Tensor>("Out");
T* data = tensor->mutable_data<T>(context.GetPlace()); T* data = tensor->mutable_data<T>(context.GetPlace());
unsigned int seed = static_cast<unsigned int>(context.Attr<int>("seed")); unsigned int seed = static_cast<unsigned int>(context.Attr<int>("seed"));
if (seed == 0) { if (seed == 0) {
std::random_device rd; std::random_device rd;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册