truncated_gaussian_random_op.cu 4.3 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include <thrust/random.h>
#include <thrust/transform.h>
#include <limits>
Y
yaoxuefeng 已提交
18
#include "paddle/fluid/framework/generator.h"
19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"

namespace paddle {
namespace operators {

// Functor that maps an element index to one sample of a normal distribution
// truncated to two standard deviations around `mean`.
//
// Each call builds its own thrust::minstd_rand engine from `seed`, skips
// ahead by the element index, takes one uniform draw and pushes it through
// the inverse normal CDF, so the value for index n depends only on
// (seed, n).
template <typename T>
struct TruncatedNormal {
  T mean, std;
  T a_normal_cdf;  // standard normal CDF evaluated at -2
  T b_normal_cdf;  // standard normal CDF evaluated at +2
  unsigned int seed;
  T numeric_min;  // lower bound handed to the uniform distribution

  __host__ __device__ TruncatedNormal(T mean, T std, T numeric_min, int seed)
      : mean(mean), std(std), seed(seed), numeric_min(numeric_min) {
    // CDF(x) = (1 + erf(x / sqrt(2))) / 2, evaluated at x = -2 and x = +2.
    const float root_two = sqrtf(2.0);
    a_normal_cdf = (1.0 + erff(-2.0 / root_two)) / 2.0;
    b_normal_cdf = (1.0 + erff(2.0 / root_two)) / 2.0;
  }

  // Returns the sample associated with element index `n`.
  __host__ __device__ T operator()(const unsigned int n) const {
    thrust::minstd_rand engine(seed);
    engine.discard(n);

    thrust::uniform_real_distribution<T> uniform(numeric_min, 1);
    const T u = uniform(engine);

    // Rescale the uniform draw into the CDF interval between the two cut
    // points, then invert the normal CDF.
    auto p = a_normal_cdf + (b_normal_cdf - a_normal_cdf) * u;
    return std::sqrt(2.0) * erfinvf(2 * p - 1) * std + mean;
  }
};

Y
yaoxuefeng 已提交
50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80
// Functor that maps an element index to one sample of a normal distribution
// truncated to two standard deviations around `mean`, starting `offset`
// positions into the random stream selected by `seed`.
//
// Each call builds its own thrust::minstd_rand engine from `seed`, skips
// ahead by (n + offset_), takes one uniform draw and pushes it through the
// inverse normal CDF. The value for index n therefore depends only on
// (seed, offset, n).
template <typename T>
struct TruncatedNormalOffset {
  T mean, std;
  T a_normal_cdf;  // standard normal CDF evaluated at -2
  T b_normal_cdf;  // standard normal CDF evaluated at +2
  unsigned int seed;
  T numeric_min;  // lower bound handed to the uniform distribution
  int offset_;    // extra number of engine states skipped before sampling

  __host__ __device__ TruncatedNormalOffset(T mean, T std, T numeric_min,
                                            int seed, int offset)
      : mean(mean),
        std(std),
        seed(seed),
        numeric_min(numeric_min),
        offset_(offset) {
    // CDF(x) = (1 + erf(x / sqrt(2))) / 2, evaluated at x = -2 and x = +2.
    a_normal_cdf = (1.0 + erff(-2.0 / sqrtf(2.0))) / 2.0;
    b_normal_cdf = (1.0 + erff(2.0 / sqrtf(2.0))) / 2.0;
  }

  // Returns the sample associated with element index `n`.
  __host__ __device__ T operator()(const unsigned int n) const {
    thrust::minstd_rand rng;
    rng.seed(seed);
    thrust::uniform_real_distribution<T> dist(numeric_min, 1);
    // Skip past the element index AND the stored offset. Discarding only `n`
    // would ignore offset_ and make this functor indistinguishable from
    // TruncatedNormal.
    rng.discard(n + offset_);
    T value = dist(rng);
    // Rescale the uniform draw into the CDF interval between the two cut
    // points, then invert the normal CDF.
    auto p = a_normal_cdf + (b_normal_cdf - a_normal_cdf) * value;
    return std::sqrt(2.0) * erfinvf(2 * p - 1) * std + mean;
  }
};

81 82 83 84 85 86
// GPU kernel for the truncated_gaussian_random operator.
//
// Fills the "Out" tensor with truncated-normal samples parameterised by the
// "mean", "std" and "seed" attributes. Exactly one of two sampling paths runs:
//   * If the "seed" attribute is 0 and the default CUDA generator of the
//     current device reports GetIsInitPy(), the generator's (seed, offset)
//     pair drives TruncatedNormalOffset.
//   * Otherwise TruncatedNormal is used with the attribute seed, or with a
//     std::random_device value when the attribute is 0.
template <typename T>
class GPUTruncatedGaussianRandomKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& context) const override {
    auto* tensor = context.Output<framework::Tensor>("Out");
    T* data = tensor->mutable_data<T>(context.GetPlace());

    unsigned int seed = static_cast<unsigned int>(context.Attr<int>("seed"));
    // seed_flag records that the attribute was 0 and a seed had to be chosen
    // here rather than supplied through the attribute.
    bool seed_flag = false;
    if (seed == 0) {
      std::random_device rd;
      seed = rd();
      seed_flag = true;
    }
    T mean = static_cast<T>(context.Attr<float>("mean"));
    T std = static_cast<T>(context.Attr<float>("std"));
    // NOTE(review): the index iterator is 32-bit while `size` is 64-bit;
    // tensors with more than UINT_MAX elements would wrap - confirm whether
    // such sizes can reach this kernel.
    thrust::counting_iterator<unsigned int> index_sequence_begin(0);
    int64_t size = tensor->numel();

    int device_id =
        BOOST_GET_CONST(platform::CUDAPlace, context.GetPlace()).GetDeviceId();
    auto gen_cuda = framework::GetDefaultCUDAGenerator(device_id);

    if (gen_cuda->GetIsInitPy() && seed_flag) {
      auto seed_offset = gen_cuda->IncrementOffset(1);
      int offset_step = 100;
      // NOTE(xuefeng): Currently, we let offset step fixed to avoid
      // unexpected results which may cause ut fail.
      // we will fix this in future.
      int gen_offset = offset_step * seed_offset.second;
      // Pass the scaled offset computed above; it was previously computed but
      // the raw seed_offset.second was passed instead.
      thrust::transform(
          index_sequence_begin, index_sequence_begin + size,
          thrust::device_ptr<T>(data),
          TruncatedNormalOffset<T>(mean, std, std::numeric_limits<T>::min(),
                                   seed_offset.first, gen_offset));
    } else {
      // The two paths must be mutually exclusive: running this transform
      // unconditionally would overwrite the generator-driven result above.
      thrust::transform(
          index_sequence_begin, index_sequence_begin + size,
          thrust::device_ptr<T>(data),
          TruncatedNormal<T>(mean, std, std::numeric_limits<T>::min(), seed));
    }
  }
};

}  // namespace operators
}  // namespace paddle

// Binds the CUDA kernel class to the truncated_gaussian_random operator.
// Only the float instantiation is listed here.
// NOTE(review): REGISTER_OP_CUDA_KERNEL is defined outside this file
// (presumably in op_registry.h) - confirm its exact semantics there.
REGISTER_OP_CUDA_KERNEL(
    truncated_gaussian_random,
    paddle::operators::GPUTruncatedGaussianRandomKernel<float>);