From 7e651c8641f8f197aa127a7eef63c2e8eb403a71 Mon Sep 17 00:00:00 2001 From: whs Date: Thu, 11 Oct 2018 10:09:44 +0800 Subject: [PATCH] Fix truncated norm (#13785) * Fix truncated normal. * test=develop --- paddle/fluid/operators/truncated_gaussian_random_op.cc | 2 +- paddle/fluid/operators/truncated_gaussian_random_op.cu | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/paddle/fluid/operators/truncated_gaussian_random_op.cc b/paddle/fluid/operators/truncated_gaussian_random_op.cc index d854e280397..1e8708f2648 100644 --- a/paddle/fluid/operators/truncated_gaussian_random_op.cc +++ b/paddle/fluid/operators/truncated_gaussian_random_op.cc @@ -148,7 +148,7 @@ struct TruncatedNormal { T operator()(T value) const { auto p = a_normal_cdf + (b_normal_cdf - a_normal_cdf) * value; - return (std::sqrt(2.0) * Erfinv(2 * p - 1) + mean) * std; + return std::sqrt(2.0) * Erfinv(2 * p - 1) * std + mean; } }; diff --git a/paddle/fluid/operators/truncated_gaussian_random_op.cu b/paddle/fluid/operators/truncated_gaussian_random_op.cu index ad2a9021bfe..5a3510babe4 100644 --- a/paddle/fluid/operators/truncated_gaussian_random_op.cu +++ b/paddle/fluid/operators/truncated_gaussian_random_op.cu @@ -42,7 +42,7 @@ struct TruncatedNormal { rng.discard(n); T value = dist(rng); auto p = a_normal_cdf + (b_normal_cdf - a_normal_cdf) * value; - return (std::sqrt(2.0) * erfinvf(2 * p - 1) + mean) * std; + return std::sqrt(2.0) * erfinvf(2 * p - 1) * std + mean; } }; @@ -52,6 +52,7 @@ class GPUTruncatedGaussianRandomKernel : public framework::OpKernel { void Compute(const framework::ExecutionContext& context) const override { auto* tensor = context.Output("Out"); T* data = tensor->mutable_data(context.GetPlace()); + unsigned int seed = static_cast(context.Attr("seed")); if (seed == 0) { std::random_device rd; -- GitLab