diff --git a/paddle/fluid/operators/gaussian_random_op.cu b/paddle/fluid/operators/gaussian_random_op.cu
index fa9fe9d8602012f71ca6829e58561d03b7bfb2f1..21d827c79200c4a368ce7677b01b18ee4ddedb8d 100644
--- a/paddle/fluid/operators/gaussian_random_op.cu
+++ b/paddle/fluid/operators/gaussian_random_op.cu
@@ -21,6 +21,7 @@ limitations under the License. */
 #include "paddle/fluid/operators/amp/fp16_type_traits.h"
 #include "paddle/fluid/operators/distribution_helper.h"
 #include "paddle/fluid/operators/fill_constant_op.h"
+#include "paddle/fluid/operators/index_impl.cu.h"
 
 DECLARE_bool(use_curand);
 
@@ -65,7 +66,6 @@ class GPUGaussianRandomKernel : public framework::OpKernel<T> {
     }
     T mean = static_cast<T>(context.Attr<float>("mean"));
     T std = static_cast<T>(context.Attr<float>("std"));
-    thrust::counting_iterator<unsigned int> index_sequence_begin(0);
     auto shape = GetShape(context);
     tensor->Resize(shape);
 
@@ -88,15 +88,13 @@
       } else {
         auto seed_offset = gen_cuda->IncrementOffset(1);
         int64_t gen_offset = size * seed_offset.second;
-        thrust::transform(
-            index_sequence_begin, index_sequence_begin + size,
-            thrust::device_ptr<T>(data),
-            GaussianGenerator<T>(mean, std, seed_offset.first, gen_offset));
+        auto func =
+            GaussianGenerator<T>(mean, std, seed_offset.first, gen_offset);
+        IndexKernel<T, GaussianGenerator<T>>(dev_cxt, tensor, func);
       }
     } else {
-      thrust::transform(index_sequence_begin, index_sequence_begin + size,
-                        thrust::device_ptr<T>(data),
-                        GaussianGenerator<T>(mean, std, seed));
+      auto func = GaussianGenerator<T>(mean, std, seed);
+      IndexKernel<T, GaussianGenerator<T>>(dev_cxt, tensor, func);
     }
   }
 };
@@ -116,23 +114,22 @@ class GPUGaussianRandomBatchSizeLikeKernel : public framework::OpKernel<T> {
     }
     T mean = static_cast<T>(context.Attr<float>("mean"));
     T std = static_cast<T>(context.Attr<float>("std"));
-    thrust::counting_iterator<unsigned int> index_sequence_begin(0);
     int64_t size = tensor->numel();
     int device_id = context.GetPlace().GetDeviceId();
     auto gen_cuda = framework::GetDefaultCUDAGenerator(device_id);
+    auto& dev_cxt =
+        context.template device_context<platform::CUDADeviceContext>();
     if (gen_cuda->GetIsInitPy() && seed_flag) {
       auto seed_offset = gen_cuda->IncrementOffset(1);
       int64_t gen_offset = size * seed_offset.second;
-      thrust::transform(index_sequence_begin, index_sequence_begin + size,
-                        thrust::device_ptr<T>(data),
-                        GaussianGenerator<T>(mean, std, seed_offset.first,
-                                             seed_offset.second));
+      auto func = GaussianGenerator<T>(mean, std, seed_offset.first,
+                                       seed_offset.second);
+      IndexKernel<T, GaussianGenerator<T>>(dev_cxt, tensor, func);
     } else {
-      thrust::transform(index_sequence_begin, index_sequence_begin + size,
-                        thrust::device_ptr<T>(data),
-                        GaussianGenerator<T>(mean, std, seed));
+      auto func = GaussianGenerator<T>(mean, std, seed);
+      IndexKernel<T, GaussianGenerator<T>>(dev_cxt, tensor, func);
     }
   }
 };
diff --git a/paddle/fluid/operators/index_impl.cu.h b/paddle/fluid/operators/index_impl.cu.h
new file mode 100644
index 0000000000000000000000000000000000000000..bae0d3f569f5f42a7ff63593352cd63579810606
--- /dev/null
+++ b/paddle/fluid/operators/index_impl.cu.h
@@ -0,0 +1,97 @@
+/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+#pragma once
+#include <algorithm>
+#include <string>
+#include <vector>
+#include "paddle/fluid/framework/generator.h"
+#include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/framework/operator.h"
+#include "paddle/fluid/operators/amp/fp16_type_traits.h"
+#include "paddle/fluid/operators/distribution_helper.h"
+#include "paddle/fluid/operators/fill_constant_op.h"
+#include "paddle/fluid/platform/aligned_vector.h"
+#include "paddle/phi/backends/gpu/gpu_launch_config.h"
+#include "paddle/phi/core/hostdevice.h"
+#include "paddle/phi/kernels/primitive/kernel_primitives.h"
+
+namespace paddle {
+namespace operators {
+
+namespace kps = phi::kps;
+template <typename T, typename Functor, int VecSize>
+__global__ void VectorizedIndexKernel(T *out, int numel, int main_offset,
+                                      Functor func) {
+  int data_offset = BLOCK_ID_X * BLOCK_NUM_X * VecSize;
+  int stride = BLOCK_NUM_X * GRID_NUM_X * VecSize;
+  int args[VecSize];
+  T result[VecSize];
+  for (; data_offset < main_offset; data_offset += stride) {
+    kps::InitWithDataIndex<int, VecSize, 1, 1>(&args[0], data_offset);
+    kps::ElementwiseUnary<int, T, VecSize, 1, 1, Functor>(&result[0], &args[0],
+                                                          func);
+    kps::WriteData<T, VecSize, 1, 1, false>(out + data_offset, &result[0],
+                                            BLOCK_NUM_X * VecSize);
+  }
+  int num = numel - data_offset;
+  if (num > 0) {
+    kps::InitWithDataIndex<int, VecSize, 1, 1>(&args[0], data_offset);
+    kps::ElementwiseUnary<int, T, VecSize, 1, 1, Functor>(&result[0], &args[0],
+                                                          func);
+    kps::WriteData<T, VecSize, 1, 1, true>(out + data_offset, &result[0], num);
+  }
+}
+
+template <typename T, typename Functor>
+void IndexKernel(const KPDevice &dev_ctx, Tensor *out, Functor func) {
+  int numel = out->numel();
+  T *out_data = out->mutable_data<T>(dev_ctx.GetPlace());
+  if (numel <= 0) return;
+  int vec_size = paddle::platform::GetVectorizedSize((out->data<T>()));
+#ifdef PADDLE_WITH_XPU_KP
+  int block = 64;
+  int grid = 8;
+  auto stream = dev_ctx.x_context()->xpu_stream;
+#else
+  auto config =
+      phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, numel, vec_size);
+  int grid = config.block_per_grid.x;
+  int block = config.thread_per_block.x;
+  auto stream = dev_ctx.stream();
+#endif
+
+  int main_offset = (numel / (vec_size * block)) * vec_size * block;
+  switch (vec_size) {
+    case 4:
+      VectorizedIndexKernel<T, Functor, 4><<<grid, block, 0, stream>>>(
+          out_data, numel, main_offset, func);
+      break;
+    case 2:
+      VectorizedIndexKernel<T, Functor, 2><<<grid, block, 0, stream>>>(
+          out_data, numel, main_offset, func);
+      break;
+    case 1:
+      VectorizedIndexKernel<T, Functor, 1><<<grid, block, 0, stream>>>(
+          out_data, numel, main_offset, func);
+      break;
+    default: {
+      PADDLE_THROW(paddle::platform::errors::Unimplemented(
+          "Unsupported vectorized size: %d !", vec_size));
+      break;
+    }
+  }
+}
+
+}  // namespace operators
+}  // namespace paddle
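Editor's note (not part of the diff): IndexKernel fills a tensor by mapping each flat element index through a device functor with the convention T operator()(const unsigned int n) const, the same convention the thrust::transform/counting_iterator pattern it replaces used. Below is a minimal illustrative sketch of such a functor and its call site; the name LinearRampGenerator is made up, and a CUDA device context dev_cxt plus an output framework::Tensor* tensor are assumed to be in scope.

// Illustrative only: any functor mapping a flat index to a value plugs into
// IndexKernel exactly like GaussianGenerator / UniformGenerator do.
template <typename T>
struct LinearRampGenerator {
  T start_, step_;
  __host__ __device__ LinearRampGenerator(T start, T step)
      : start_(start), step_(step) {}
  __host__ __device__ T operator()(const unsigned int n) const {
    // The produced value depends only on the flat index n, so the fill is
    // deterministic regardless of how threads are vectorized or tiled.
    return start_ + step_ * static_cast<T>(n);
  }
};

// Sketch of a call site inside an op kernel:
//   auto func = LinearRampGenerator<T>(static_cast<T>(0), static_cast<T>(1));
//   IndexKernel<T, LinearRampGenerator<T>>(dev_cxt, tensor, func);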
diff --git a/paddle/fluid/operators/uniform_random_inplace_op.cu b/paddle/fluid/operators/uniform_random_inplace_op.cu
index a5231354eb47ea3b5cdd802a9b77f7ba7e313c1e..1c7b9a27f868821ceb20c720548b4df0ee6bcd40 100644
--- a/paddle/fluid/operators/uniform_random_inplace_op.cu
+++ b/paddle/fluid/operators/uniform_random_inplace_op.cu
@@ -12,130 +12,17 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
 */
-#include <thrust/random.h>
-#include <thrust/transform.h>
-#include <algorithm>
-#include <random>
-#include "paddle/fluid/framework/generator.h"
-#include "paddle/fluid/framework/op_registry.h"
-#include "paddle/fluid/framework/operator.h"
+#include "paddle/fluid/operators/uniform_random_op.h"
+#include "paddle/phi/kernels/full_kernel.h"
 
 namespace paddle {
 namespace operators {
-
-template <typename T>
-struct UniformGenerator {
-  T min_, max_;
-  unsigned int seed_;
-  T diag_val_;
-  unsigned int diag_num_;
-  unsigned int diag_step_;
-  __host__ __device__ UniformGenerator(T min, T max, int seed, int diag_num,
-                                       int diag_step, T diag_val)
-      : min_(min),
-        max_(max),
-        seed_(seed),
-        diag_num_(diag_num),
-        diag_step_(diag_step),
-        diag_val_(diag_val) {}
-
-  __host__ __device__ T operator()(const unsigned int n) const {
-    thrust::minstd_rand rng;
-    rng.seed(seed_);
-    thrust::uniform_real_distribution<T> dist(min_, max_);
-    rng.discard(n);
-    T out = dist(rng);
-    unsigned int remainder = n % (diag_step_ + 1);
-    if (remainder == 0 && diag_num_ > n / (diag_step_ + 1)) {
-      out = diag_val_;
-    }
-    return out;
-  }
-};
-
-template <typename T>
-struct UniformGeneratorOffset {
-  T min_, max_;
-  unsigned int seed_;
-  T diag_val_;
-  unsigned int diag_num_;
-  unsigned int diag_step_;
-  int offset_;
-  __host__ __device__ UniformGeneratorOffset(T min, T max, int seed,
-                                             int diag_num, int diag_step,
-                                             T diag_val, int offset)
-      : min_(min),
-        max_(max),
-        seed_(seed),
-        diag_num_(diag_num),
-        diag_step_(diag_step),
-        diag_val_(diag_val),
-        offset_(offset) {}
-
-  __host__ __device__ T operator()(const unsigned int n) const {
-    thrust::minstd_rand rng;
-    rng.seed(seed_);
-    thrust::uniform_real_distribution<T> dist(min_, max_);
-    rng.discard(n + offset_);
-    T out = dist(rng);
-    unsigned int remainder = n % (diag_step_ + 1);
-    if (remainder == 0 && diag_num_ > n / (diag_step_ + 1)) {
-      out = diag_val_;
-    }
-    return out;
-  }
-};
-
-template <typename T>
-__global__ void fill_value(int64_t size, T* data, float value) {
-  for (int idx = threadIdx.x; idx < size; idx += blockDim.x) {
-    data[idx] = static_cast<T>(value);
-  }
-}
-
-// It seems that Eigen::Tensor::random in GPU will SEGFAULT.
-// Use std::random and thrust::random(thrust is a std library in CUDA) to
-// implement uniform random as uniform_random_op.cu.
 template <typename T>
 class GPUUniformRandomInplaceKernel : public framework::OpKernel<T> {
  public:
-  void Compute(const framework::ExecutionContext& ctx) const override {
-    auto out_var = ctx.OutputVar("Out");
-    auto* tensor = out_var->GetMutable<framework::LoDTensor>();
-    T* data = tensor->mutable_data<T>(ctx.GetPlace());
-    unsigned int seed = static_cast<unsigned int>(ctx.Attr<int>("seed"));
-    bool seed_flag = false;
-    if (seed == 0) {
-      std::random_device rd;
-      seed = rd();
-      seed_flag = true;
-    }
-
-    T min = static_cast<T>(ctx.Attr<float>("min"));
-    T max = static_cast<T>(ctx.Attr<float>("max"));
-    unsigned int diag_num =
-        static_cast<unsigned int>(ctx.Attr<int>("diag_num"));
-    unsigned int diag_step =
-        static_cast<unsigned int>(ctx.Attr<int>("diag_step"));
-    T diag_val = static_cast<T>(ctx.Attr<float>("diag_val"));
-    thrust::counting_iterator<unsigned int> index_sequence_begin(0);
-    int64_t size = tensor->numel();
-    int device_id = ctx.GetPlace().GetDeviceId();
-    auto gen_cuda = framework::GetDefaultCUDAGenerator(device_id);
-    if (gen_cuda->GetIsInitPy() && seed_flag) {
-      auto seed_offset = gen_cuda->IncrementOffset(1);
-      int64_t gen_offset = size * seed_offset.second;
-      thrust::transform(
-          index_sequence_begin, index_sequence_begin + size,
-          thrust::device_ptr<T>(data),
-          UniformGeneratorOffset<T>(min, max, seed_offset.first, diag_num,
-                                    diag_step, diag_val, gen_offset));
-    } else {
-      thrust::transform(
-          index_sequence_begin, index_sequence_begin + size,
-          thrust::device_ptr<T>(data),
-          UniformGenerator<T>(min, max, seed, diag_num, diag_step, diag_val));
-    }
+  void Compute(const framework::ExecutionContext& context) const override {
+    auto* tensor = context.Output<framework::Tensor>("Out");
+    UniformRandom<T>(context, tensor);
   }
 };
 
@@ -143,17 +30,15 @@ template <typename T>
 class GPUUniformRandomInplaceGradKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& ctx) const override {
-#ifdef __HIPCC__
-    const int64_t kMaxBlockDim = 256;
-#else
-    const int64_t kMaxBlockDim = 512;
-#endif
     auto* dx = ctx.Output<framework::Tensor>(framework::GradVarName("X"));
-    auto* data = dx->mutable_data<T>(ctx.GetPlace());
-
-    auto size = dx->numel();
-    int64_t kBlockDim = std::min(size, kMaxBlockDim);
-    fill_value<T><<<1, kBlockDim, 0>>>(size, data, static_cast<float>(0));
+    auto dims = vectorize(dx->dims());
+    const auto& dev_cxt =
+        ctx.template device_context<platform::CUDADeviceContext>();
+    float value = static_cast<float>(0.0f);
+    phi::FullKernel<T>(
+        static_cast<const typename paddle::framework::ConvertToPhiContext<paddle::platform::CUDADeviceContext>::TYPE&>(dev_cxt),
+        dims, value, phi::DataType::UNDEFINED, dx);
   }
 };
 
diff --git a/paddle/fluid/operators/uniform_random_op.cu b/paddle/fluid/operators/uniform_random_op.cu
index 086c57527b48ffc940c029fb462afd6c22d86f98..fb38a6aded4cf173bb4c0dd96d131ff520b6701e 100644
--- a/paddle/fluid/operators/uniform_random_op.cu
+++ b/paddle/fluid/operators/uniform_random_op.cu
@@ -11,88 +11,11 @@ distributed under the License is distributed on an "AS IS" BASIS,
 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
 */
-#include <thrust/random.h>
-#include <thrust/transform.h>
-#include <random>
-#include <string>
-#include "paddle/fluid/framework/generator.h"
-#include "paddle/fluid/framework/op_registry.h"
-#include "paddle/fluid/framework/operator.h"
-#include "paddle/fluid/operators/amp/fp16_type_traits.h"
-#include "paddle/fluid/operators/distribution_helper.h"
 #include "paddle/fluid/operators/uniform_random_op.h"
-DECLARE_bool(use_curand);
-
 namespace paddle {
 namespace operators {
-template <typename T>
-struct UniformGenerator {
-  T min_, max_;
-  unsigned int seed_;
-  T diag_val_;
-  unsigned int diag_num_;
-  unsigned int diag_step_;
-  __host__ __device__ UniformGenerator(T min, T max, int seed, int diag_num,
-                                       int diag_step, T diag_val)
-      : min_(min),
-        max_(max),
-        seed_(seed),
-        diag_num_(diag_num),
-        diag_step_(diag_step),
-        diag_val_(diag_val) {}
-
-  __host__ __device__ T operator()(const unsigned int n) const {
-    thrust::minstd_rand rng;
-    rng.seed(seed_);
-    thrust::uniform_real_distribution<T> dist(min_, max_);
-    rng.discard(n);
-    T out = dist(rng);
-    unsigned int remainder = n % (diag_step_ + 1);
-    if (remainder == 0 && diag_num_ > n / (diag_step_ + 1)) {
-      out = diag_val_;
-    }
-    return out;
-  }
-};
-
-template <typename T>
-struct UniformGeneratorOffset {
-  T min_, max_;
-  unsigned int seed_;
-  T diag_val_;
-  unsigned int diag_num_;
-  unsigned int diag_step_;
-  int offset_;
-  __host__ __device__ UniformGeneratorOffset(T min, T max, int seed,
-                                             int diag_num, int diag_step,
-                                             T diag_val, int offset)
-      : min_(min),
-        max_(max),
-        seed_(seed),
-        diag_num_(diag_num),
-        diag_step_(diag_step),
-        diag_val_(diag_val),
-        offset_(offset) {}
-
-  __host__ __device__ T operator()(const unsigned int n) const {
-    thrust::minstd_rand rng;
-    rng.seed(seed_);
-    thrust::uniform_real_distribution<T> dist(min_, max_);
-    rng.discard(n + offset_);
-    T out = dist(rng);
-    unsigned int remainder = n % (diag_step_ + 1);
-    if (remainder == 0 && diag_num_ > n / (diag_step_ + 1)) {
-      out = diag_val_;
-    }
-    return out;
-  }
-};
-
-// It seems that Eigen::Tensor::random in GPU will SEGFAULT.
-// Use std::random and thrust::random(thrust is a std library in CUDA) to
-// implement uniform random.
 template <typename T>
 class GPUUniformRandomKernel : public framework::OpKernel<T> {
  public:
@@ -128,50 +51,7 @@ class GPUUniformRandomKernel : public framework::OpKernel<T> {
           "unsupport type: %s.",
           framework::ToTypeName(out_var->Type())));
     }
-    auto& dev_cxt =
-        context.template device_context<platform::CUDADeviceContext>();
-    T* data = tensor->mutable_data<T>(dev_cxt.GetPlace());
-    unsigned int seed = static_cast<unsigned int>(context.Attr<int>("seed"));
-    bool seed_flag = false;
-    if (seed == 0) {
-      std::random_device rd;
-      seed = rd();
-      seed_flag = true;
-    }
-
-    T min = static_cast<T>(context.Attr<float>("min"));
-    T max = static_cast<T>(context.Attr<float>("max"));
-    unsigned int diag_num =
-        static_cast<unsigned int>(context.Attr<int>("diag_num"));
-    unsigned int diag_step =
-        static_cast<unsigned int>(context.Attr<int>("diag_step"));
-    T diag_val = static_cast<T>(context.Attr<float>("diag_val"));
-    thrust::counting_iterator<unsigned int> index_sequence_begin(0);
-    int64_t size = tensor->numel();
-    int device_id = context.GetPlace().GetDeviceId();
-    auto gen_cuda = framework::GetDefaultCUDAGenerator(device_id);
-    if (gen_cuda->GetIsInitPy() && seed_flag) {
-      if (FLAGS_use_curand) {
-        using MT = typename details::MPTypeTrait<T>::Type;
-        distribution::uniform_distribution<MT> dist;
-        distribution::uniform_transform<MT> trans(min, max);
-        distribution::distribution_and_transform<T>(dev_cxt, tensor, dist,
-                                                    trans);
-      } else {
-        auto seed_offset = gen_cuda->IncrementOffset(1);
-        int64_t gen_offset = size * seed_offset.second;
-        thrust::transform(
-            index_sequence_begin, index_sequence_begin + size,
-            thrust::device_ptr<T>(data),
-            UniformGeneratorOffset<T>(min, max, seed_offset.first, diag_num,
-                                      diag_step, diag_val, gen_offset));
-      }
-    } else {
-      thrust::transform(
-          index_sequence_begin, index_sequence_begin + size,
-          thrust::device_ptr<T>(data),
-          UniformGenerator<T>(min, max, seed, diag_num, diag_step, diag_val));
-    }
+    UniformRandom<T>(context, tensor);
   }
 };
 
diff --git a/paddle/fluid/operators/uniform_random_op.h b/paddle/fluid/operators/uniform_random_op.h
index be6c3c740e692c17504fb36bd807c06768da2ee9..a864c48ad757411861b6d2b3be40361c347601f8 100644
--- a/paddle/fluid/operators/uniform_random_op.h
+++ b/paddle/fluid/operators/uniform_random_op.h
@@ -18,6 +18,16 @@
 #include <vector>
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/framework/operator.h"
+#if defined(__NVCC__) || defined(__HIPCC__)
+DECLARE_bool(use_curand);
+#include <curand_kernel.h>
+#include <thrust/device_vector.h>
+#include <thrust/random.h>
+#include <thrust/transform.h>
+#include "paddle/fluid/framework/generator.h"
+#include "paddle/fluid/operators/index_impl.cu.h"
+#include "paddle/phi/kernels/full_kernel.h"
+#endif
 
 namespace paddle {
 namespace operators {
@@ -102,5 +112,117 @@ inline std::vector<int64_t> GetNewDataFromShapeTensorList(
   return vec_new_shape;
 }
+
+#if defined(__NVCC__) || defined(__HIPCC__)
+
+template <typename T>
+struct UniformGenerator {
+  T min_, max_;
+  unsigned int seed_;
+  T diag_val_;
+  unsigned int diag_num_;
+  unsigned int diag_step_;
+  __host__ __device__ UniformGenerator(T min, T max, int seed, int diag_num,
+                                       int diag_step, T diag_val)
+      : min_(min),
+        max_(max),
+        seed_(seed),
+        diag_num_(diag_num),
+        diag_step_(diag_step),
+        diag_val_(diag_val) {}
+
+  __host__ __device__ T operator()(const unsigned int n) const {
+    thrust::minstd_rand rng;
+    rng.seed(seed_);
+    thrust::uniform_real_distribution<T> dist(min_, max_);
+    rng.discard(n);
+    T out = dist(rng);
+    unsigned int remainder = n % (diag_step_ + 1);
+    if (remainder == 0 && diag_num_ > n / (diag_step_ + 1)) {
+      out = diag_val_;
+    }
+    return out;
+  }
+};
+
diag_step_; + int offset_; + __host__ __device__ UniformGeneratorOffset(T min, T max, int seed, + int diag_num, int diag_step, + T diag_val, int offset) + : min_(min), + max_(max), + seed_(seed), + diag_num_(diag_num), + diag_step_(diag_step), + diag_val_(diag_val), + offset_(offset) {} + + __host__ __device__ T operator()(const unsigned int n) const { + thrust::minstd_rand rng; + rng.seed(seed_); + thrust::uniform_real_distribution dist(min_, max_); + rng.discard(n + offset_); + T out = dist(rng); + unsigned int remainder = n % (diag_step_ + 1); + if (remainder == 0 && diag_num_ > n / (diag_step_ + 1)) { + out = diag_val_; + } + return out; + } +}; + +template +void UniformRandom(const framework::ExecutionContext& context, + framework::Tensor* tensor) { + int64_t size = tensor->numel(); + auto& dev_cxt = + context.template device_context(); + T* data = tensor->mutable_data(dev_cxt.GetPlace()); + if (size <= 0) return; + unsigned int seed = static_cast(context.Attr("seed")); + bool seed_flag = false; + if (seed == 0) { + std::random_device rd; + seed = rd(); + seed_flag = true; + } + + T min = static_cast(context.Attr("min")); + T max = static_cast(context.Attr("max")); + unsigned int diag_num = + static_cast(context.Attr("diag_num")); + unsigned int diag_step = + static_cast(context.Attr("diag_step")); + T diag_val = static_cast(context.Attr("diag_val")); + int device_id = context.GetPlace().GetDeviceId(); + auto gen_cuda = framework::GetDefaultCUDAGenerator(device_id); + if (gen_cuda->GetIsInitPy() && seed_flag) { + if (FLAGS_use_curand) { + using MT = typename details::MPTypeTrait::Type; + distribution::uniform_distribution dist; + distribution::uniform_transform trans(min, max); + distribution::distribution_and_transform(dev_cxt, tensor, dist, trans); + } else { + auto seed_offset = gen_cuda->IncrementOffset(1); + int64_t gen_offset = size * seed_offset.second; + auto func = + UniformGeneratorOffset(min, max, seed_offset.first, diag_num, + diag_step, diag_val, gen_offset); + IndexKernel>(dev_cxt, tensor, func); + } + } else { + auto func = + UniformGenerator(min, max, seed, diag_num, diag_step, diag_val); + IndexKernel>(dev_cxt, tensor, func); + } +} +#endif } // namespace operators } // namespace paddle diff --git a/paddle/phi/kernels/primitive/datamover_primitives.h b/paddle/phi/kernels/primitive/datamover_primitives.h index 120be251db2c80abf8975739e655eaed4f3425c3..a6c4c40a7505e11f9b2d971fdb48c4bc94ab5aa7 100644 --- a/paddle/phi/kernels/primitive/datamover_primitives.h +++ b/paddle/phi/kernels/primitive/datamover_primitives.h @@ -714,5 +714,14 @@ __device__ __forceinline__ void ReadDataBc( } } +template +__device__ __forceinline__ void InitWithDataIndex(T* dst, int block_offset) { + int thread_offset = block_offset + threadIdx.x * NX; +#pragma unroll + for (int nx = 0; nx < NX; ++nx) { + dst[nx] = static_cast(thread_offset + nx); + } +} + } // namespace kps } // namespace phi diff --git a/python/paddle/fluid/tests/unittests/test_gaussian_random_op.py b/python/paddle/fluid/tests/unittests/test_gaussian_random_op.py index 70ab1cc523507edcdfb361beaaa7b44742ba10cd..43bcc3438eef49f2918be6ee460ee25db9f246e8 100644 --- a/python/paddle/fluid/tests/unittests/test_gaussian_random_op.py +++ b/python/paddle/fluid/tests/unittests/test_gaussian_random_op.py @@ -21,7 +21,7 @@ import paddle.fluid as fluid import paddle.fluid.core as core from paddle.fluid.op import Operator from paddle.fluid.executor import Executor -from op_test import OpTest +from paddle.fluid.tests.unittests.op_test import 
OpTest import paddle
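Editor's note (not part of the diff): a quick worked example of the tiling arithmetic in index_impl.cu.h. main_offset is the largest multiple of vec_size * block not exceeding numel, so the grid-stride loop covers whole vectorized tiles and the boundary branch writes the remaining tail via WriteData(..., num). The sizes below are assumptions chosen only for illustration; the real block/grid values come from GetGpuLaunchConfig1D.

#include <cstdio>

// Standalone host-side sketch of IndexKernel's main_offset / tail split.
int main() {
  int numel = 10000, vec_size = 4, block = 256;
  int tile = vec_size * block;              // elements handled per block per iteration
  int main_offset = (numel / tile) * tile;  // 9216: covered by full vectorized tiles
  int tail = numel - main_offset;           // 784: handled by the boundary branch
  std::printf("main_offset=%d tail=%d\n", main_offset, tail);
  return 0;
}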