/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include <thrust/device_ptr.h>
#include <thrust/iterator/counting_iterator.h>
#include <thrust/random.h>
#include <thrust/transform.h>

#include <algorithm>
#include <cstdint>
#include <random>

#include "paddle/fluid/framework/generator.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"

namespace paddle {
namespace operators {

// Produces the sample for flat element index `n`.
//
// A fresh minstd_rand engine is seeded with `seed` and advanced by `skip`
// draws, so the value depends only on (seed, skip) and no state is shared
// between elements (which is what lets thrust::transform evaluate elements
// in any order).
//
// Diagonal override: when `n` is a multiple of (diag_step + 1) and the
// quotient n / (diag_step + 1) is below `diag_num`, `diag_val` is returned
// instead of a random draw.
// NOTE(review): diag_step == UINT_MAX makes the period 0 (division by zero);
// presumably the op's attribute checks rule that out - confirm.
template <typename T>
__host__ __device__ inline T UniformSampleWithDiag(
    unsigned int seed, T min, T max, unsigned long long skip,  // NOLINT
    unsigned int n, unsigned int diag_num, unsigned int diag_step,
    T diag_val) {
  const unsigned int period = diag_step + 1;
  if (n % period == 0 && diag_num > n / period) {
    // The random value would be overwritten anyway; skip generating it.
    return diag_val;
  }
  thrust::minstd_rand rng;
  rng.seed(seed);
  rng.discard(skip);
  thrust::uniform_real_distribution<T> dist(min, max);
  return dist(rng);
}

// Thrust functor: maps element index `n` to a uniform sample in [min, max)
// drawn from a stream seeded with `seed`, with the diagonal override applied
// (see UniformSampleWithDiag).
template <typename T>
struct UniformGenerator {
  T min_, max_;
  unsigned int seed_;
  T diag_val_;
  unsigned int diag_num_;
  unsigned int diag_step_;

  // Initializers follow member declaration order (avoids -Wreorder).
  __host__ __device__ UniformGenerator(T min, T max, unsigned int seed,
                                       int diag_num, int diag_step,
                                       T diag_val)
      : min_(min),
        max_(max),
        seed_(seed),
        diag_val_(diag_val),
        diag_num_(diag_num),
        diag_step_(diag_step) {}

  __host__ __device__ T operator()(const unsigned int n) const {
    return UniformSampleWithDiag<T>(seed_, min_, max_, n, n, diag_num_,
                                    diag_step_, diag_val_);
  }
};

// Same as UniformGenerator, but element `n` uses draw number `n + offset_`
// of the seeded stream. Kernel code uses the offset to consume a fresh
// window of the global generator's stream on each call.
template <typename T>
struct UniformGeneratorOffset {
  T min_, max_;
  unsigned int seed_;
  T diag_val_;
  unsigned int diag_num_;
  unsigned int diag_step_;
  // 64-bit: the offset is size * generator_offset, which easily exceeds the
  // range of int for large tensors or long-running jobs.
  uint64_t offset_;

  __host__ __device__ UniformGeneratorOffset(T min, T max, unsigned int seed,
                                             int diag_num, int diag_step,
                                             T diag_val, uint64_t offset)
      : min_(min),
        max_(max),
        seed_(seed),
        diag_val_(diag_val),
        diag_num_(diag_num),
        diag_step_(diag_step),
        offset_(offset) {}

  __host__ __device__ T operator()(const unsigned int n) const {
    // Widen before adding so the sum cannot wrap in 32-bit arithmetic.
    const unsigned long long skip =  // NOLINT
        static_cast<unsigned long long>(n) + offset_;  // NOLINT
    return UniformSampleWithDiag<T>(seed_, min_, max_, skip, n, diag_num_,
                                    diag_step_, diag_val_);
  }
};

// Writes static_cast<T>(value) into data[0, size).
// Grid-stride loop with 64-bit indices: correct for any grid/block
// configuration (including a single block) and for sizes above 2^31.
template <typename T>
__global__ void fill_value(int64_t size, T* data, float value) {
  const T fill = static_cast<T>(value);
  const int64_t stride = static_cast<int64_t>(gridDim.x) * blockDim.x;
  for (int64_t idx = static_cast<int64_t>(blockIdx.x) * blockDim.x +
                     threadIdx.x;
       idx < size; idx += stride) {
    data[idx] = fill;
  }
}

// Eigen::Tensor::random reportedly segfaults on GPU, so this op samples with
// std::random_device (for seeding) and Thrust's random engines (Thrust ships
// with the CUDA toolkit), in the same way uniform_random_op.cu does.
//
// Fills the tensor bound to output "Out" in place with uniform samples in
// [min, max), then applies the diagonal override controlled by the
// "diag_num" / "diag_step" / "diag_val" attributes.
template <typename T>
class GPUUniformRandomInplaceKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    auto out_var = ctx.OutputVar("Out");
    auto* tensor = out_var->GetMutable<framework::LoDTensor>();
    T* data = tensor->mutable_data<T>(ctx.GetPlace());

    // seed == 0 means "not fixed by the user": draw a nondeterministic seed,
    // and remember that so the global CUDA generator may override it below.
    unsigned int seed = static_cast<unsigned int>(ctx.Attr<int>("seed"));
    bool seed_flag = false;
    if (seed == 0) {
      std::random_device rd;
      seed = rd();
      seed_flag = true;
    }

    T min = static_cast<T>(ctx.Attr<float>("min"));
    T max = static_cast<T>(ctx.Attr<float>("max"));
    unsigned int diag_num =
        static_cast<unsigned int>(ctx.Attr<int>("diag_num"));
    unsigned int diag_step =
        static_cast<unsigned int>(ctx.Attr<int>("diag_step"));
    T diag_val = static_cast<T>(ctx.Attr<float>("diag_val"));

    thrust::counting_iterator<int64_t> index_sequence_begin(0);
    int64_t size = tensor->numel();
    int device_id =
        BOOST_GET_CONST(platform::CUDAPlace, ctx.GetPlace()).GetDeviceId();
    auto gen_cuda = framework::GetDefaultCUDAGenerator(device_id);

    // NOTE(review): both transforms below run on Thrust's default stream,
    // not on the op's device-context stream - confirm ordering is guaranteed
    // by the caller.
    if (gen_cuda->GetIsInitPy() && seed_flag) {
      // Use the global generator's seed and advance its offset so successive
      // calls consume disjoint windows of the random stream.
      auto seed_offset = gen_cuda->IncrementOffset(1);
      const uint64_t gen_offset =
          static_cast<uint64_t>(size) * seed_offset.second;
      thrust::transform(
          index_sequence_begin, index_sequence_begin + size,
          thrust::device_ptr<T>(data),
          UniformGeneratorOffset<T>(
              min, max, static_cast<unsigned int>(seed_offset.first),
              diag_num, diag_step, diag_val, gen_offset));
    } else {
      thrust::transform(
          index_sequence_begin, index_sequence_begin + size,
          thrust::device_ptr<T>(data),
          UniformGenerator<T>(min, max, seed, diag_num, diag_step, diag_val));
    }
  }
};

// Gradient of the in-place uniform fill: the output does not depend on X,
// so X@GRAD is all zeros.
template <typename T>
class GPUUniformRandomInplaceGradKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
#ifdef __HIPCC__
    const int64_t kMaxBlockDim = 256;
#else
    const int64_t kMaxBlockDim = 512;
#endif
    // Upper bound on blocks for the zero fill; fill_value is grid-stride, so
    // any cap is correct - this only bounds launch overhead.
    const int64_t kMaxGridDim = 1024;

    auto* dx = ctx.Output<framework::LoDTensor>(framework::GradVarName("X"));
    auto* data = dx->mutable_data<T>(ctx.GetPlace());

    const int64_t size = dx->numel();
    if (size == 0) {
      // A launch with zero threads is an invalid configuration.
      return;
    }
    const int64_t block_dim = std::min(size, kMaxBlockDim);
    const int64_t grid_dim =
        std::min((size + block_dim - 1) / block_dim, kMaxGridDim);
    // NOTE(review): launched on the default stream, as before - confirm this
    // matches the stream used by downstream consumers of X@GRAD.
    fill_value<T><<<static_cast<unsigned int>(grid_dim),
                    static_cast<unsigned int>(block_dim), 0>>>(
        size, data, static_cast<float>(0));
  }
};

}  // namespace operators
}  // namespace paddle

REGISTER_OP_CUDA_KERNEL(
    uniform_random_inplace,
    paddle::operators::GPUUniformRandomInplaceKernel<float>,
    paddle::operators::GPUUniformRandomInplaceKernel<double>);
REGISTER_OP_CUDA_KERNEL(
    uniform_random_inplace_grad,
    paddle::operators::GPUUniformRandomInplaceGradKernel<float>,
    paddle::operators::GPUUniformRandomInplaceGradKernel<double>);