/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <thrust/device_vector.h>
#include <thrust/host_vector.h>
#include <thrust/random.h>
#include <thrust/transform.h>
#include "paddle/fluid/framework/generator.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/operators/amp/fp16_type_traits.h"
#include "paddle/fluid/operators/distribution_helper.h"
#include "paddle/fluid/operators/uniform_random_op.h"

// gflags boolean defined elsewhere in the project.  When set, the kernel
// below samples through the distribution_helper path instead of the
// per-element thrust functors (see the FLAGS_use_curand branch in
// GPUUniformRandomKernel::Compute).
DECLARE_bool(use_curand);

namespace paddle {
namespace operators {

// Element-wise functor: maps a flat index n to a sample from U(min_, max_).
// Every (diag_step_ + 1)-th position, for the first diag_num_ such
// positions, is overwritten with diag_val_ instead.
template <typename T>
struct UniformGenerator {
  T min_, max_;
  unsigned int seed_;
  T diag_val_;
  unsigned int diag_num_;
  unsigned int diag_step_;
  __host__ __device__ UniformGenerator(T min, T max, int seed, int diag_num,
                                       int diag_step, T diag_val)
      : min_(min),
        max_(max),
        seed_(seed),
        diag_num_(diag_num),
        diag_step_(diag_step),
        diag_val_(diag_val) {}

  __host__ __device__ T operator()(const unsigned int n) const {
    // Per-element engine: seed identically everywhere, then skip ahead n
    // states so each index draws a distinct, reproducible value.
    thrust::minstd_rand engine;
    engine.seed(seed_);
    engine.discard(n);
    thrust::uniform_real_distribution<T> uniform(min_, max_);
    T value = uniform(engine);
    const unsigned int cycle = diag_step_ + 1;
    if (n % cycle == 0 && n / cycle < diag_num_) {
      value = diag_val_;
    }
    return value;
  }
};

// Same contract as UniformGenerator, but the engine additionally skips
// offset_ states before drawing, so successive launches can consume
// disjoint regions of the same random stream.
template <typename T>
struct UniformGeneratorOffset {
  T min_, max_;
  unsigned int seed_;
  T diag_val_;
  unsigned int diag_num_;
  unsigned int diag_step_;
  int offset_;
  __host__ __device__ UniformGeneratorOffset(T min, T max, int seed,
                                             int diag_num, int diag_step,
                                             T diag_val, int offset)
      : min_(min),
        max_(max),
        seed_(seed),
        diag_num_(diag_num),
        diag_step_(diag_step),
        diag_val_(diag_val),
        offset_(offset) {}

  __host__ __device__ T operator()(const unsigned int n) const {
    // Seed identically everywhere; discard(n + offset_) makes element n of
    // this launch independent of element n of previous launches.
    thrust::minstd_rand engine;
    engine.seed(seed_);
    engine.discard(n + offset_);
    thrust::uniform_real_distribution<T> uniform(min_, max_);
    T value = uniform(engine);
    const unsigned int cycle = diag_step_ + 1;
    if (n % cycle == 0 && n / cycle < diag_num_) {
      value = diag_val_;
    }
    return value;
  }
};

// It seems that Eigen::Tensor::random in GPU will SEGFAULT.
// Use std::random and thrust::random(thrust is a std library in CUDA) to
// implement uniform random.
template <typename T>
class GPUUniformRandomKernel : public framework::OpKernel<T> {
 public:
  // Fills Output("Out") with samples from U(min, max); the first `diag_num`
  // positions on a (diag_step + 1)-strided "diagonal" are forced to
  // `diag_val`.  The static "shape" attribute can be overridden at runtime
  // by Input("ShapeTensor") or Input("ShapeTensorList").
  void Compute(const framework::ExecutionContext& context) const override {
    framework::Tensor* tensor = nullptr;
    auto out_var = context.OutputVar("Out");
    // Fix: out_var was dereferenced below without a null check; fail with a
    // clear message instead of crashing if Output(Out) is absent.
    PADDLE_ENFORCE_NOT_NULL(
        out_var,
        platform::errors::NotFound(
            "Output(Out) of uniform_random_op should not be null."));

    // Runtime-provided shape takes precedence over the "shape" attribute;
    // ShapeTensor takes precedence over ShapeTensorList.
    std::vector<int64_t> new_shape;
    auto list_new_shape_tensor =
        context.MultiInput<framework::Tensor>("ShapeTensorList");
    if (list_new_shape_tensor.size() > 0 || context.HasInput("ShapeTensor")) {
      if (context.HasInput("ShapeTensor")) {
        auto* shape_tensor = context.Input<framework::Tensor>("ShapeTensor");
        new_shape = GetNewDataFromShapeTensor(shape_tensor);
      } else if (list_new_shape_tensor.size() > 0) {
        new_shape = GetNewDataFromShapeTensorList(list_new_shape_tensor);
      }
    }

    if (out_var->IsType<phi::SelectedRows>()) {
      auto* selected_rows = out_var->GetMutable<phi::SelectedRows>();
      tensor = selected_rows->mutable_value();
      auto shape = context.Attr<std::vector<int64_t>>("shape");
      if (!new_shape.empty()) shape = new_shape;
      tensor->Resize(phi::make_ddim(shape));
      selected_rows->mutable_rows()->reserve(shape[0]);
    } else if (out_var->IsType<framework::LoDTensor>()) {
      tensor = out_var->GetMutable<framework::LoDTensor>();
      if (!new_shape.empty()) tensor->Resize(phi::make_ddim(new_shape));
    } else {
      PADDLE_THROW(platform::errors::InvalidArgument(
          "Expected type of Output(out) in uniform_random_op must be Tensor, "
          "SelectedRows. But got "
          "unsupport type: %s.",
          framework::ToTypeName(out_var->Type())));
    }
    auto& dev_ctx =
        context.template device_context<platform::CUDADeviceContext>();
    T* data = tensor->mutable_data<T>(dev_ctx.GetPlace());

    // Attr "seed" == 0 means "no fixed seed": draw a fresh one on the host.
    unsigned int seed = static_cast<unsigned int>(context.Attr<int>("seed"));
    bool seed_flag = false;
    if (seed == 0) {
      std::random_device rd;
      seed = rd();
      seed_flag = true;
    }

    T min = static_cast<T>(context.Attr<float>("min"));
    T max = static_cast<T>(context.Attr<float>("max"));
    unsigned int diag_num =
        static_cast<unsigned int>(context.Attr<int>("diag_num"));
    unsigned int diag_step =
        static_cast<unsigned int>(context.Attr<int>("diag_step"));
    T diag_val = static_cast<T>(context.Attr<float>("diag_val"));
    thrust::counting_iterator<int64_t> index_sequence_begin(0);
    int64_t size = tensor->numel();
    int device_id = context.GetPlace().GetDeviceId();
    auto gen_cuda = framework::GetDefaultCUDAGenerator(device_id);
    if (gen_cuda->GetIsInitPy() && seed_flag) {
      if (FLAGS_use_curand) {
        // distribution_helper path: sample/transform in the mixed-precision
        // compute type MT, then store as T.  NOTE: this path does not apply
        // the diag_* overwrite (same as the original behavior).
        using MT = typename details::MPTypeTrait<T>::Type;
        distribution::uniform_distribution<MT> dist;
        distribution::uniform_transform<MT> trans(min, max);
        distribution::distribution_and_transform<T>(dev_ctx, tensor, dist,
                                                    trans);
      } else {
        // Reserve an offset from the global generator so this launch reads
        // a region of the stream disjoint from earlier launches.
        auto seed_offset = gen_cuda->IncrementOffset(1);
        int64_t gen_offset = size * seed_offset.second;
        thrust::transform(
            index_sequence_begin, index_sequence_begin + size,
            thrust::device_ptr<T>(data),
            UniformGeneratorOffset<T>(min, max, seed_offset.first, diag_num,
                                      diag_step, diag_val, gen_offset));
      }
    } else {
      // No initialized Python-side generator (or an explicit seed attr):
      // use the plain per-element functor with the seed computed above.
      thrust::transform(
          index_sequence_begin, index_sequence_begin + size,
          thrust::device_ptr<T>(data),
          UniformGenerator<T>(min, max, seed, diag_num, diag_step, diag_val));
    }
  }
};

}  // namespace operators
}  // namespace paddle
// Register the GPU kernels: float and double instantiations for both
// uniform_random and uniform_random_batch_size_like.
REGISTER_OP_CUDA_KERNEL(uniform_random,
                        paddle::operators::GPUUniformRandomKernel<float>,
                        paddle::operators::GPUUniformRandomKernel<double>);
REGISTER_OP_CUDA_KERNEL(uniform_random_batch_size_like,
                        paddle::operators::GPUUniformRandomKernel<float>,
                        paddle::operators::GPUUniformRandomKernel<double>);