From b34663876056740261a9f58cf3e5d90e9e49788f Mon Sep 17 00:00:00 2001
From: Leo Chen
Date: Tue, 1 Mar 2022 11:25:24 +0800
Subject: [PATCH] [phi] move uniform_random to phi (#39937)

* move uniform_random to phi

* fit selected_rows

* replace mutable_data
---
 paddle/fluid/framework/operator.cc            |   3 +
 paddle/fluid/operators/uniform_random_op.cc   |   4 -
 paddle/fluid/operators/uniform_random_op.cu   |   3 -
 .../phi/kernels/cpu/uniform_random_kernel.cc  | 115 ++++++++
 paddle/phi/kernels/funcs/aligned_vector.h     |  75 ++++++
 .../phi/kernels/funcs/distribution_helper.h   | 249 ++++++++++++++++++
 paddle/phi/kernels/funcs/index_impl.cu.h      |  93 +++++++
 .../phi/kernels/gpu/uniform_random_kernel.cu  | 163 ++++++++++++
 .../selected_rows/uniform_random_kernel.cc    |  88 +++++++
 paddle/phi/kernels/uniform_random_kernel.h    |  66 +++++
 paddle/phi/ops/compat/uniform_random_sig.cc   | 159 +++++++++++
 11 files changed, 1011 insertions(+), 7 deletions(-)
 create mode 100644 paddle/phi/kernels/cpu/uniform_random_kernel.cc
 create mode 100644 paddle/phi/kernels/funcs/aligned_vector.h
 create mode 100644 paddle/phi/kernels/funcs/distribution_helper.h
 create mode 100644 paddle/phi/kernels/funcs/index_impl.cu.h
 create mode 100644 paddle/phi/kernels/gpu/uniform_random_kernel.cu
 create mode 100644 paddle/phi/kernels/selected_rows/uniform_random_kernel.cc
 create mode 100644 paddle/phi/kernels/uniform_random_kernel.h
 create mode 100644 paddle/phi/ops/compat/uniform_random_sig.cc

diff --git a/paddle/fluid/framework/operator.cc b/paddle/fluid/framework/operator.cc
index d33791f70c..36208c41ed 100644
--- a/paddle/fluid/framework/operator.cc
+++ b/paddle/fluid/framework/operator.cc
@@ -2074,6 +2074,7 @@ void OperatorWithKernel::BuildPhiKernelContext(
     }
     pt_kernel_context->AssignInputRange(std::make_pair(start_idx, end_idx), i);
   }
+  VLOG(4) << "Done inputs";
 
   for (size_t i = 0; i < output_names.size(); ++i) {
     auto it = ctx.outputs.find(output_names[i]);
@@ -2118,6 +2119,7 @@ void OperatorWithKernel::BuildPhiKernelContext(
     pt_kernel_context->AssignOutputRange(std::make_pair(start_idx, end_idx),
                                          i);
   }
+  VLOG(4) << "Done outputs";
 
   for (size_t i = 0; i < attr_names.size(); ++i) {
     if (attr_defs[i].type_index == std::type_index(typeid(phi::ScalarArray))) {
@@ -2226,6 +2228,7 @@ void OperatorWithKernel::BuildPhiKernelContext(
       }
     }
   }
+  VLOG(4) << "Done attributes";
 }
 
 }  // namespace framework
diff --git a/paddle/fluid/operators/uniform_random_op.cc b/paddle/fluid/operators/uniform_random_op.cc
index 353d653f48..1c22e60fa8 100644
--- a/paddle/fluid/operators/uniform_random_op.cc
+++ b/paddle/fluid/operators/uniform_random_op.cc
@@ -281,10 +281,6 @@ REGISTER_OPERATOR(
     paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>,
     paddle::operators::UniformRandomOpVarTypeInference);
 
-REGISTER_OP_CPU_KERNEL(
-    uniform_random, paddle::operators::CPUUniformRandomKernel<float>,
-    paddle::operators::CPUUniformRandomKernel<double>,
-    paddle::operators::CPUUniformRandomKernel<paddle::platform::bfloat16>);
 REGISTER_OP_CPU_KERNEL(
     uniform_random_batch_size_like,
     paddle::operators::CPUUniformRandomKernel<float>,
diff --git a/paddle/fluid/operators/uniform_random_op.cu b/paddle/fluid/operators/uniform_random_op.cu
index fb38a6aded..2ceb8a68d8 100644
--- a/paddle/fluid/operators/uniform_random_op.cu
+++ b/paddle/fluid/operators/uniform_random_op.cu
@@ -58,9 +58,6 @@ class GPUUniformRandomKernel : public framework::OpKernel<T> {
 }  // namespace operators
 }  // namespace paddle
 
-REGISTER_OP_CUDA_KERNEL(uniform_random,
-                        paddle::operators::GPUUniformRandomKernel<float>,
-                        paddle::operators::GPUUniformRandomKernel<double>);
 REGISTER_OP_CUDA_KERNEL(uniform_random_batch_size_like,
                         paddle::operators::GPUUniformRandomKernel<float>,
                         paddle::operators::GPUUniformRandomKernel<double>);
diff --git a/paddle/phi/kernels/cpu/uniform_random_kernel.cc b/paddle/phi/kernels/cpu/uniform_random_kernel.cc
new file mode 100644
index 0000000000..8ec1d9683e
--- /dev/null
+++ b/paddle/phi/kernels/cpu/uniform_random_kernel.cc
@@ -0,0 +1,115 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/phi/kernels/uniform_random_kernel.h"
+
+#include "paddle/phi/core/kernel_registry.h"
+
+namespace phi {
+
+template <typename T>
+inline void UniformRealDistribution(T *data,
+                                    const int64_t &size,
+                                    const float &min,
+                                    const float &max,
+                                    std::shared_ptr<std::mt19937_64> engine) {
+  std::uniform_real_distribution<T> dist(static_cast<T>(min),
+                                         static_cast<T>(max));
+  for (int64_t i = 0; i < size; ++i) {
+    data[i] = dist(*engine);
+  }
+}
+
+template <>
+inline void UniformRealDistribution(phi::dtype::bfloat16 *data,
+                                    const int64_t &size,
+                                    const float &min,
+                                    const float &max,
+                                    std::shared_ptr<std::mt19937_64> engine) {
+  std::uniform_real_distribution<float> dist(min, max);
+  for (int64_t i = 0; i < size; ++i) {
+    data[i] = static_cast<phi::dtype::bfloat16>(dist(*engine));
+  }
+}
+
+template <typename T, typename Context>
+void UniformRandomRawKernel(const Context &dev_ctx,
+                            const ScalarArray &shape,
+                            DataType dtype,
+                            float min,
+                            float max,
+                            int seed,
+                            int diag_num,
+                            int diag_step,
+                            float diag_val,
+                            DenseTensor *out) {
+  out->Resize(phi::make_ddim(shape.GetData()));
+  VLOG(4) << out->dims();
+  T *data = dev_ctx.template Alloc<T>(out);
+  auto size = out->numel();
+  std::shared_ptr<std::mt19937_64> engine;
+  if (seed) {
+    engine = std::make_shared<std::mt19937_64>();
+    engine->seed(seed);
+  } else {
+    engine = dev_ctx.GetGenerator()->GetCPUEngine();
+  }
+  UniformRealDistribution<T>(data, size, min, max, engine);
+  if (diag_num > 0) {
+    PADDLE_ENFORCE_GT(
+        size,
+        (diag_num - 1) * (diag_step + 1),
+        phi::errors::InvalidArgument(
+            "ShapeInvalid: the index of the last diagonal element, "
+            "(diag_num - 1) * (diag_step + 1) = %d (with diag_num %d, "
+            "diag_step %d), must be smaller than the tensor size %d.",
+            (diag_num - 1) * (diag_step + 1),
+            diag_num,
+            diag_step,
+            size));
+    for (int64_t i = 0; i < diag_num; ++i) {
+      int64_t pos = i * diag_step + i;
+      data[pos] = diag_val;
+    }
+  }
+}
+
+template <typename T, typename Context>
+void UniformRandomKernel(const Context &dev_ctx,
+                         const ScalarArray &shape,
+                         DataType dtype,
+                         float min,
+                         float max,
+                         int seed,
+                         DenseTensor *out) {
+  UniformRandomRawKernel<T>(
+      dev_ctx, shape, dtype, min, max, seed, 0, 0, 0.0f, out);
+}
+
+}  // namespace phi
+
+PD_REGISTER_KERNEL(uniform_random_raw,
+                   CPU,
+                   ALL_LAYOUT,
+                   phi::UniformRandomRawKernel,
+                   float,
+                   double,
+                   phi::dtype::bfloat16) {}
+
+PD_REGISTER_KERNEL(uniform_random,
+                   CPU,
+                   ALL_LAYOUT,
+                   phi::UniformRandomKernel,
+                   float,
+                   double,
+                   phi::dtype::bfloat16) {}
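For intuition, the CPU path above boils down to "seeded engine, uniform draws, optional diagonal overwrite". The following standalone sketch mirrors that logic with plain standard-library types; the function name, buffer size, and parameter values are illustrative, not part of the patch:

    // Minimal re-enactment of the CPU kernel's sampling logic: fill a buffer
    // uniformly in [min, max) with std::mt19937_64 (the seed != 0 branch),
    // then overwrite every (diag_step + 1)-th element with diag_val.
    #include <cstdint>
    #include <iostream>
    #include <random>
    #include <vector>

    void FillUniformWithDiag(std::vector<float>* buf, float min, float max,
                             uint64_t seed, int diag_num, int diag_step,
                             float diag_val) {
      std::mt19937_64 engine(seed);
      std::uniform_real_distribution<float> dist(min, max);
      for (auto& v : *buf) v = dist(engine);
      // The kernel first asserts size > (diag_num - 1) * (diag_step + 1).
      for (int64_t i = 0; i < diag_num; ++i) {
        buf->at(i * diag_step + i) = diag_val;  // pos = i * (diag_step + 1)
      }
    }

    int main() {
      std::vector<float> buf(16);
      FillUniformWithDiag(&buf, -1.f, 1.f, 42, 4, 3, 9.f);
      for (float v : buf) std::cout << v << ' ';  // 9 at indices 0, 4, 8, 12
    }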
diff --git a/paddle/phi/kernels/funcs/aligned_vector.h b/paddle/phi/kernels/funcs/aligned_vector.h
new file mode 100644
index 0000000000..9382b03cf9
--- /dev/null
+++ b/paddle/phi/kernels/funcs/aligned_vector.h
@@ -0,0 +1,75 @@
+/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+
+#include "paddle/phi/core/hostdevice.h"
+
+namespace phi {
+
+// Aligned vector generates vectorized load/store on CUDA.
+template <typename T, int Size>
+struct alignas(sizeof(T) * Size) AlignedVector {
+  T val[Size];
+
+  HOSTDEVICE inline const T& operator[](int i) const { return val[i]; }
+  HOSTDEVICE inline T& operator[](int i) { return val[i]; }
+};
+
+template <typename T, int Size>
+HOSTDEVICE inline void Load(const T* addr, AlignedVector<T, Size>* vec) {
+  const AlignedVector<T, Size>* addr_vec =
+      reinterpret_cast<const AlignedVector<T, Size>*>(addr);
+  *vec = *addr_vec;
+}
+
+template <typename T, int Size>
+HOSTDEVICE inline void Store(const AlignedVector<T, Size>& vec, T* addr) {
+  AlignedVector<T, Size>* addr_vec =
+      reinterpret_cast<AlignedVector<T, Size>*>(addr);
+  *addr_vec = vec;
+}
+
+/*
+ * A vectorized load is possible only when the data address is aligned to a
+ * multiple of the vector width (1, 2, or 4 elements). Moreover, a single
+ * vectorized load moves at most 128 bits. The valid vectorized length is
+ * therefore determined by both constraints.
+ */
+template <typename T>
+int GetVectorizedSize(const T* pointer) {
+  constexpr int max_load_bits = 128;
+  int valid_vec_size = max_load_bits / CHAR_BIT / sizeof(T);
+  uint64_t address = reinterpret_cast<uint64_t>(pointer);
+  constexpr int vec8 = std::alignment_of<AlignedVector<T, 8>>::value;  // NOLINT
+  constexpr int vec4 = std::alignment_of<AlignedVector<T, 4>>::value;  // NOLINT
+  constexpr int vec2 = std::alignment_of<AlignedVector<T, 2>>::value;  // NOLINT
+  if (address % vec8 == 0) {
+    /*
+     * Currently we handle at most 4 elements per vectorized access. If a
+     * performance test shows that handling 8 elements at once is faster,
+     * the return below can be changed to "return std::min(8, valid_vec_size);".
+     */
+    return std::min(4, valid_vec_size);
+  } else if (address % vec4 == 0) {
+    return std::min(4, valid_vec_size);
+  } else if (address % vec2 == 0) {
+    return std::min(2, valid_vec_size);
+  } else {
+    return 1;
+  }
+}
+
+}  // namespace phi
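aligned_vector.h picks the widest safe vector width from the pointer's run-time alignment. A host-only sketch of that dispatch, compilable as-is (HOSTDEVICE expands to nothing off-device, so it is dropped here; the buffer and offsets are made up):

    // Alignment-based dispatch, mirroring GetVectorizedSize above.
    #include <algorithm>
    #include <climits>
    #include <cstdint>
    #include <iostream>
    #include <type_traits>

    template <typename T, int Size>
    struct alignas(sizeof(T) * Size) AlignedVector {
      T val[Size];
    };

    template <typename T>
    int GetVectorizedSize(const T* pointer) {
      constexpr int max_load_bits = 128;  // widest single load: 128 bits
      int valid_vec_size = max_load_bits / CHAR_BIT / sizeof(T);
      uint64_t address = reinterpret_cast<uint64_t>(pointer);
      if (address % std::alignment_of<AlignedVector<T, 8>>::value == 0) {
        return std::min(4, valid_vec_size);  // capped at 4, as in the header
      } else if (address % std::alignment_of<AlignedVector<T, 4>>::value == 0) {
        return std::min(4, valid_vec_size);
      } else if (address % std::alignment_of<AlignedVector<T, 2>>::value == 0) {
        return std::min(2, valid_vec_size);
      }
      return 1;
    }

    int main() {
      alignas(32) float data[8];
      std::cout << GetVectorizedSize(data) << '\n';      // 4: 32-byte aligned
      std::cout << GetVectorizedSize(data + 1) << '\n';  // 1: 4-byte aligned only
    }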
diff --git a/paddle/phi/kernels/funcs/distribution_helper.h b/paddle/phi/kernels/funcs/distribution_helper.h
new file mode 100644
index 0000000000..49e1c82482
--- /dev/null
+++ b/paddle/phi/kernels/funcs/distribution_helper.h
@@ -0,0 +1,249 @@
+/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+
+#ifdef __NVCC__
+#include <curand_kernel.h>
+#endif
+#ifdef __HIPCC__
+#include <hiprand_kernel.h>
+#endif
+
+#include "paddle/phi/backends/gpu/gpu_info.h"
+#include "paddle/phi/core/dense_tensor.h"
+#include "paddle/phi/core/device_context.h"
+#include "paddle/phi/core/generator.h"
+
+#include "paddle/phi/kernels/funcs/index_impl.cu.h"
+
+#if defined(__NVCC__) || defined(__HIPCC__)
+#include "paddle/phi/kernels/primitive/kernel_primitives.h"
+#endif
+
+#if !defined(_WIN32)
+#define UNLIKELY(condition) __builtin_expect(static_cast<bool>(condition), 0)
+#else
+// There is no equivalent intrinsic in MSVC.
+#define UNLIKELY(condition) (condition)
+#endif
+
+namespace phi {
+namespace distribution {
+
+/********************* Transformation Function **********************/
+template <typename T>
+struct exponential_transform {
+  explicit exponential_transform(T lambda) : lambda_(lambda) {}
+
+  HOSTDEVICE inline T operator()(T val) const {
+#if defined(__NVCC__) || defined(__HIPCC__)
+    if (std::is_same<T, double>::value) {
+      return static_cast<T>(-1.0) / lambda_ * log(val);
+    } else {
+      return static_cast<T>(-1.0) / lambda_ * __logf(val);
+    }
+#else
+    return static_cast<T>(-1.0) / lambda_ *
+           std::log(static_cast<T>(1.0) - val);
+#endif
+  }
+
+ private:
+  T lambda_;
+};
+
+template <typename T>
+struct uniform_transform {
+  explicit uniform_transform(T min, T max) : range_(max - min), min_(min) {}
+
+  HOSTDEVICE inline T operator()(T val) const {
+    if (UNLIKELY(val == static_cast<T>(1.0))) {
+      return min_;
+    } else {
+      return val * range_ + min_;
+    }
+  }
+
+ private:
+  T range_;
+  T min_;
+};
+
+template <typename T>
+struct normal_transform {
+  explicit normal_transform(T mean, T std) : mean_(mean), std_(std) {}
+
+  HOSTDEVICE inline T operator()(T val) const { return val * std_ + mean_; }
+
+ private:
+  T mean_;
+  T std_;
+};
+
+#if defined(__NVCC__) || defined(__HIPCC__)
+
+namespace kps = phi::kps;
+
+/*********************** Distribution Function *************************/
+template <typename T>
+struct uniform_distribution;
+
+template <typename T>
+struct normal_distribution;
+
+#if defined(__NVCC__)
+template <>
+struct uniform_distribution<float> {
+  __device__ inline float4 operator()(curandStatePhilox4_32_10_t *state) const {
+    return curand_uniform4(state);
+  }
+  static constexpr int kReturnsCount = 4;
+};
+
+template <>
+struct uniform_distribution<double> {
+  __device__ inline double2 operator()(
+      curandStatePhilox4_32_10_t *state) const {
+    return curand_uniform2_double(state);
+  }
+  static constexpr int kReturnsCount = 2;
+};
+
+template <>
+struct normal_distribution<float> {
+  __device__ inline float4 operator()(curandStatePhilox4_32_10_t *state) const {
+    return curand_normal4(state);
+  }
+  static constexpr int kReturnsCount = 4;
+};
+
+template <>
+struct normal_distribution<double> {
+  __device__ inline double2 operator()(
+      curandStatePhilox4_32_10_t *state) const {
+    return curand_normal2_double(state);
+  }
+  static constexpr int kReturnsCount = 2;
+};
+
+#else
+template <>
+struct uniform_distribution<float> {
+  __device__ inline float4 operator()(
+      hiprandStatePhilox4_32_10_t *state) const {
+    return hiprand_uniform4(state);
+  }
+  static constexpr int kReturnsCount = 4;
+};
+
+template <>
+struct uniform_distribution<double> {
+  __device__ inline double2 operator()(
+      hiprandStatePhilox4_32_10_t *state) const {
+    return hiprand_uniform2_double(state);
+  }
+  static constexpr int kReturnsCount = 2;
+};
+
+template <>
+struct normal_distribution<float> {
+  __device__ inline float4 operator()(
+      hiprandStatePhilox4_32_10_t *state) const {
+    return hiprand_normal4(state);
+  }
+  static constexpr int kReturnsCount = 4;
+};
+
+template <>
+struct normal_distribution<double> {
+  __device__ inline double2 operator()(
+      hiprandStatePhilox4_32_10_t *state) const {
+    return hiprand_normal2_double(state);
+  }
+  static constexpr int kReturnsCount = 2;
+};
+#endif
+
+/******** Launch GPU function of distribution and transformation *********/
+template <typename T, typename DistOp, typename TransformOp>
+__global__ void DistributionKernel(size_t size,
+                                   uint64_t seed,
+                                   uint64_t offset,
+                                   DistOp dist,
+                                   TransformOp trans,
+                                   T *out_data,
+                                   size_t stride) {
+  size_t idx = static_cast<size_t>(BLOCK_ID_X * BLOCK_NUM_X);
+  static constexpr int kCount = DistOp::kReturnsCount;
+#if defined(__NVCC__)
+  curandStatePhilox4_32_10_t state;
+  curand_init(seed, idx + THREAD_ID_X, offset, &state);
+  using SType = curandStatePhilox4_32_10_t;
+#else
+  hiprandStatePhilox4_32_10_t state;
+  hiprand_init(seed, idx + THREAD_ID_X, offset, &state);
+  using SType = hiprandStatePhilox4_32_10_t;
+#endif
+  size_t total_thread = GRID_NUM_X * BLOCK_NUM_X;
+  T args[kCount];
+  T result[kCount];
+  for (size_t i = idx; i < size; i += total_thread * kCount) {
+    kps::ElementwiseRandom<SType, T, kCount, 1, DistOp>(&args[0], dist, &state);
+    kps::ElementwiseUnary<T, T, kCount, 1, 1, TransformOp>(
+        &result[0], &args[0], trans);
+    kps::WriteData<T, T, kCount, 1, 1, true>(
+        out_data + i, &result[0], size - i, 1, stride, 1);
+    __syncthreads();
+  }
+}
+
+template <typename T, typename DistOp, typename TransformOp>
+void distribution_and_transform(const GPUContext &dev_ctx,
+                                DenseTensor *out,
+                                DistOp dist,
+                                TransformOp trans) {
+  T *out_data = dev_ctx.template Alloc<T>(out);
+  auto size = out->numel();
+
+  int64_t device_id = dev_ctx.GetPlace().GetDeviceId();
+  auto gen_cuda = dev_ctx.GetGenerator();
+
+  size_t block_size = 256;
+  size_t expect_grid_size = (size + block_size - 1) / block_size;
+  const auto &prop = backends::gpu::GetDeviceProperties(device_id);
+  size_t max_grid_size = (prop.maxThreadsPerMultiProcessor / block_size) *
+                         prop.multiProcessorCount;
+  size_t grid_size =
+      expect_grid_size > max_grid_size ? max_grid_size : expect_grid_size;
+
+  size_t total_thread = block_size * grid_size;
+  size_t curand4_loop_times =
+      (size + 4 * total_thread - 1) / (4 * total_thread);
+  // 'increment' should be a multiple of 4
+  uint64_t increment = curand4_loop_times * 4;
+
+  auto seed_offset = gen_cuda->IncrementOffset(increment);
+  uint64_t seed = seed_offset.first;
+  uint64_t offset = seed_offset.second;
+
+  DistributionKernel<T, DistOp, TransformOp>
+      <<<grid_size, block_size, 0, dev_ctx.stream()>>>(
+          size, seed, offset, dist, trans, out_data, total_thread);
+}
+
+#endif
+}  // namespace distribution
+}  // namespace phi
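The launch math in distribution_and_transform is easy to sanity-check on the host: each Philox4 call yields kReturnsCount values, so the per-thread counter advance ('increment') is rounded up to a multiple of 4. A standalone arithmetic sketch; the SM count and threads-per-SM below are assumed values, not queried device properties:

    // Back-of-the-envelope check of grid sizing and Philox increment.
    #include <cstddef>
    #include <iostream>

    int main() {
      size_t size = 1 << 20;             // tensor elements
      size_t block_size = 256;
      size_t expect_grid = (size + block_size - 1) / block_size;
      size_t max_threads_per_sm = 2048;  // assumed device property
      size_t sm_count = 80;              // assumed device property
      size_t max_grid = (max_threads_per_sm / block_size) * sm_count;
      size_t grid = expect_grid > max_grid ? max_grid : expect_grid;
      size_t total_thread = block_size * grid;
      // Each thread loops this many times, drawing 4 values per iteration.
      size_t loop_times = (size + 4 * total_thread - 1) / (4 * total_thread);
      std::cout << "grid=" << grid                       // 640, capped
                << " increment=" << loop_times * 4 << '\n';  // 8
    }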
diff --git a/paddle/phi/kernels/funcs/index_impl.cu.h b/paddle/phi/kernels/funcs/index_impl.cu.h
new file mode 100644
index 0000000000..ccb70fe25d
--- /dev/null
+++ b/paddle/phi/kernels/funcs/index_impl.cu.h
@@ -0,0 +1,93 @@
+/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+
+#include <thrust/device_vector.h>
+#include <thrust/host_vector.h>
+#include <thrust/random.h>
+
+#include "paddle/phi/backends/gpu/gpu_launch_config.h"
+#include "paddle/phi/core/dense_tensor.h"
+#include "paddle/phi/core/hostdevice.h"
+#include "paddle/phi/kernels/funcs/aligned_vector.h"
+#include "paddle/phi/kernels/primitive/kernel_primitives.h"
+
+namespace phi {
+
+template <typename T, typename Functor, int VecSize>
+__global__ void VectorizedIndexKernel(T *out,
+                                      size_t numel,
+                                      size_t main_offset,
+                                      Functor func) {
+  size_t data_offset = BLOCK_ID_X * BLOCK_NUM_X * VecSize;
+  size_t stride = BLOCK_NUM_X * GRID_NUM_X * VecSize;
+  size_t args[VecSize];
+  T result[VecSize];
+  for (; data_offset < main_offset; data_offset += stride) {
+    kps::InitWithDataIndex<size_t, VecSize, 1, 1>(&args[0], data_offset);
+    kps::ElementwiseUnary<size_t, T, VecSize, 1, 1, Functor>(
+        &result[0], &args[0], func);
+    kps::WriteData<T, VecSize, 1, 1, false>(
+        out + data_offset, &result[0], BLOCK_NUM_X * VecSize);
+  }
+  size_t num = numel - data_offset;
+  if (num > 0) {
+    kps::InitWithDataIndex<size_t, VecSize, 1, 1>(&args[0], data_offset);
+    kps::ElementwiseUnary<size_t, T, VecSize, 1, 1, Functor>(
+        &result[0], &args[0], func);
+    kps::WriteData<T, VecSize, 1, 1, true>(out + data_offset, &result[0], num);
+  }
+}
+
+template <typename T, typename Functor>
+void IndexKernel(const KPDevice &dev_ctx, DenseTensor *out, Functor func) {
+  int numel = out->numel();
+  T *out_data = dev_ctx.template Alloc<T>(out);
+  if (numel <= 0) return;
+  int vec_size = phi::GetVectorizedSize(out_data);
+#ifdef PADDLE_WITH_XPU_KP
+  int block = 64;
+  int grid = 8;
+  auto stream = dev_ctx.x_context()->xpu_stream;
+#else
+  auto config =
+      phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, numel, vec_size);
+  int grid = config.block_per_grid.x;
+  int block = config.thread_per_block.x;
+  auto stream = dev_ctx.stream();
+#endif
+  size_t main_offset = (numel / (vec_size * block)) * vec_size * block;
+  switch (vec_size) {
+    case 4:
+      VectorizedIndexKernel<T, Functor, 4><<<grid, block, 0, stream>>>(
+          out_data, numel, main_offset, func);
+      break;
+    case 2:
+      VectorizedIndexKernel<T, Functor, 2><<<grid, block, 0, stream>>>(
+          out_data, numel, main_offset, func);
+      break;
+    case 1:
+      VectorizedIndexKernel<T, Functor, 1><<<grid, block, 0, stream>>>(
+          out_data, numel, main_offset, func);
+      break;
+    default: {
+      PADDLE_THROW(phi::errors::Unimplemented(
+          "Unsupported vectorized size: %d!", vec_size));
+      break;
+    }
+  }
+}
+
+}  // namespace phi
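IndexKernel splits the output into a fully vectorized main region and a tail handled by the boundary-checked WriteData overload. A host-side sketch of that split; the numbers are illustrative:

    // How main_offset partitions the work in VectorizedIndexKernel.
    #include <cstddef>
    #include <iostream>

    int main() {
      int numel = 1000, vec_size = 4, block = 128;
      // Largest multiple of (vec_size * block) not exceeding numel.
      size_t main_offset = (numel / (vec_size * block)) * vec_size * block;
      std::cout << "main region: [0, " << main_offset << "), tail: "
                << numel - main_offset << " elements\n";  // main 512, tail 488
    }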
diff --git a/paddle/phi/kernels/gpu/uniform_random_kernel.cu b/paddle/phi/kernels/gpu/uniform_random_kernel.cu
new file mode 100644
index 0000000000..7f24a6667e
--- /dev/null
+++ b/paddle/phi/kernels/gpu/uniform_random_kernel.cu
@@ -0,0 +1,163 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/phi/kernels/uniform_random_kernel.h"
+
+#include "gflags/gflags.h"
+
+#include "paddle/phi/core/kernel_registry.h"
+#include "paddle/phi/kernels/funcs/distribution_helper.h"
+#include "paddle/phi/kernels/funcs/index_impl.cu.h"
+
+DECLARE_bool(use_curand);
+
+namespace phi {
+
+template <typename T>
+struct UniformGenerator {
+  T min_, max_;
+  unsigned int seed_;
+  T diag_val_;
+  unsigned int diag_num_;
+  unsigned int diag_step_;
+  __host__ __device__ UniformGenerator(
+      T min, T max, int seed, int diag_num, int diag_step, T diag_val)
+      : min_(min),
+        max_(max),
+        seed_(seed),
+        diag_num_(diag_num),
+        diag_step_(diag_step),
+        diag_val_(diag_val) {}
+
+  __host__ __device__ T operator()(const unsigned int n) const {
+    thrust::minstd_rand rng;
+    rng.seed(seed_);
+    thrust::uniform_real_distribution<T> dist(min_, max_);
+    rng.discard(n);
+    T out = dist(rng);
+    unsigned int remainder = n % (diag_step_ + 1);
+    if (remainder == 0 && diag_num_ > n / (diag_step_ + 1)) {
+      out = diag_val_;
+    }
+    return out;
+  }
+};
+
+template <typename T>
+struct UniformGeneratorOffset {
+  T min_, max_;
+  unsigned int seed_;
+  T diag_val_;
+  unsigned int diag_num_;
+  unsigned int diag_step_;
+  int offset_;
+  __host__ __device__ UniformGeneratorOffset(T min,
+                                             T max,
+                                             int seed,
+                                             int diag_num,
+                                             int diag_step,
+                                             T diag_val,
+                                             int offset)
+      : min_(min),
+        max_(max),
+        seed_(seed),
+        diag_num_(diag_num),
+        diag_step_(diag_step),
+        diag_val_(diag_val),
+        offset_(offset) {}
+
+  __host__ __device__ T operator()(const unsigned int n) const {
+    thrust::minstd_rand rng;
+    rng.seed(seed_);
+    thrust::uniform_real_distribution<T> dist(min_, max_);
+    rng.discard(n + offset_);
+    T out = dist(rng);
+    unsigned int remainder = n % (diag_step_ + 1);
+    if (remainder == 0 && diag_num_ > n / (diag_step_ + 1)) {
+      out = diag_val_;
+    }
+    return out;
+  }
+};
+
+template <typename T, typename Context>
+void UniformRandomRawKernel(const Context& dev_ctx,
+                            const ScalarArray& shape,
+                            DataType dtype,
+                            float min,
+                            float max,
+                            int seed,
+                            int diag_num,
+                            int diag_step,
+                            float diag_val,
+                            DenseTensor* out) {
+  out->Resize(phi::make_ddim(shape.GetData()));
+  T* data = dev_ctx.template Alloc<T>(out);
+  auto size = out->numel();
+  bool seed_flag = false;
+  if (seed == 0) {
+    std::random_device rd;
+    seed = rd();
+    seed_flag = true;
+  }
+
+  auto generator = dev_ctx.GetGenerator();
+  if (generator->GetIsInitPy() && seed_flag) {
+    if (FLAGS_use_curand) {
+      using MT = typename kps::details::MPTypeTrait<T>::Type;
+      distribution::uniform_distribution<MT> dist;
+      distribution::uniform_transform<MT> trans(min, max);
+      distribution::distribution_and_transform<T>(dev_ctx, out, dist, trans);
+    } else {
+      auto seed_offset = generator->IncrementOffset(1);
+      int64_t gen_offset = size * seed_offset.second;
+      auto func = UniformGeneratorOffset<T>(min,
+                                            max,
+                                            seed_offset.first,
+                                            diag_num,
+                                            diag_step,
+                                            diag_val,
+                                            gen_offset);
+      IndexKernel<T, UniformGeneratorOffset<T>>(dev_ctx, out, func);
+    }
+  } else {
+    auto func =
+        UniformGenerator<T>(min, max, seed, diag_num, diag_step, diag_val);
+    IndexKernel<T, UniformGenerator<T>>(dev_ctx, out, func);
+  }
+}
+
+template <typename T, typename Context>
+void UniformRandomKernel(const Context& dev_ctx,
+                         const ScalarArray& shape,
+                         DataType dtype,
+                         float min,
+                         float max,
+                         int seed,
+                         DenseTensor* out) {
+  UniformRandomRawKernel<T>(
+      dev_ctx, shape, dtype, min, max, seed, 0, 0, 0.0f, out);
+}
+
+}  // namespace phi
+
+PD_REGISTER_KERNEL(uniform_random_raw,
+                   GPU,
+                   ALL_LAYOUT,
+                   phi::UniformRandomRawKernel,
+                   float,
+                   double) {}
+
+PD_REGISTER_KERNEL(
+    uniform_random, GPU, ALL_LAYOUT, phi::UniformRandomKernel, float, double) {}
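The thrust-based fallback above makes each output element a pure function of (seed, n): reseed, discard n draws, sample once, so the result is independent of thread scheduling. A host re-enactment with std::minstd_rand, which implements the same minimal-standard LCG as thrust::minstd_rand (an assumption worth verifying before relying on bit-exact parity):

    // Stateless per-index sampling, as in UniformGenerator::operator().
    #include <iostream>
    #include <random>

    float UniformAt(unsigned n, unsigned seed, float min, float max) {
      std::minstd_rand rng(seed);
      rng.discard(n);  // jump to the n-th draw, making the mapping stateless
      std::uniform_real_distribution<float> dist(min, max);
      return dist(rng);
    }

    int main() {
      // Same (seed, n) -> same value, regardless of evaluation order.
      std::cout << UniformAt(7, 42, -1.f, 1.f) << ' '
                << UniformAt(7, 42, -1.f, 1.f) << '\n';
    }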
diff --git a/paddle/phi/kernels/selected_rows/uniform_random_kernel.cc b/paddle/phi/kernels/selected_rows/uniform_random_kernel.cc
new file mode 100644
index 0000000000..881180b71b
--- /dev/null
+++ b/paddle/phi/kernels/selected_rows/uniform_random_kernel.cc
@@ -0,0 +1,88 @@
+/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/phi/kernels/uniform_random_kernel.h"
+
+#include "paddle/phi/core/kernel_registry.h"
+
+namespace phi {
+
+template <typename T, typename Context>
+void UniformRandomRawSRKernel(const Context& dev_ctx,
+                              const ScalarArray& shape,
+                              DataType dtype,
+                              float min,
+                              float max,
+                              int seed,
+                              int diag_num,
+                              int diag_step,
+                              float diag_val,
+                              SelectedRows* out) {
+  phi::UniformRandomRawKernel<T>(dev_ctx,
+                                 shape,
+                                 dtype,
+                                 min,
+                                 max,
+                                 seed,
+                                 diag_num,
+                                 diag_step,
+                                 diag_val,
+                                 out->mutable_value());
+}
+
+template <typename T, typename Context>
+void UniformRandomSRKernel(const Context& dev_ctx,
+                           const ScalarArray& shape,
+                           DataType dtype,
+                           float min,
+                           float max,
+                           int seed,
+                           SelectedRows* out) {
+  phi::UniformRandomKernel<T>(
+      dev_ctx, shape, dtype, min, max, seed, out->mutable_value());
+}
+
+}  // namespace phi
+
+PD_REGISTER_KERNEL(uniform_random_raw_sr,
+                   CPU,
+                   ALL_LAYOUT,
+                   phi::UniformRandomRawSRKernel,
+                   float,
+                   double,
+                   phi::dtype::bfloat16) {}
+
+PD_REGISTER_KERNEL(uniform_random_sr,
+                   CPU,
+                   ALL_LAYOUT,
+                   phi::UniformRandomSRKernel,
+                   float,
+                   double,
+                   phi::dtype::bfloat16) {}
+
+#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
+
+PD_REGISTER_KERNEL(uniform_random_raw_sr,
+                   GPU,
+                   ALL_LAYOUT,
+                   phi::UniformRandomRawSRKernel,
+                   float,
+                   double) {}
+
+PD_REGISTER_KERNEL(uniform_random_sr,
+                   GPU,
+                   ALL_LAYOUT,
+                   phi::UniformRandomSRKernel,
+                   float,
+                   double) {}
+#endif
diff --git a/paddle/phi/kernels/uniform_random_kernel.h b/paddle/phi/kernels/uniform_random_kernel.h
new file mode 100644
index 0000000000..5bba127278
--- /dev/null
+++ b/paddle/phi/kernels/uniform_random_kernel.h
@@ -0,0 +1,66 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include "paddle/phi/common/scalar_array.h"
+#include "paddle/phi/core/dense_tensor.h"
+#include "paddle/phi/core/device_context.h"
+#include "paddle/phi/core/selected_rows.h"
+
+namespace phi {
+
+template <typename T, typename Context>
+void UniformRandomRawKernel(const Context& dev_ctx,
+                            const ScalarArray& shape,
+                            DataType dtype,
+                            float min,
+                            float max,
+                            int seed,
+                            int diag_num,
+                            int diag_step,
+                            float diag_val,
+                            DenseTensor* out);
+
+template <typename T, typename Context>
+void UniformRandomKernel(const Context& dev_ctx,
+                         const ScalarArray& shape,
+                         DataType dtype,
+                         float min,
+                         float max,
+                         int seed,
+                         DenseTensor* out);
+
+template <typename T, typename Context>
+void UniformRandomRawSRKernel(const Context& dev_ctx,
+                              const ScalarArray& shape,
+                              DataType dtype,
+                              float min,
+                              float max,
+                              int seed,
+                              int diag_num,
+                              int diag_step,
+                              float diag_val,
+                              SelectedRows* out);
+
+template <typename T, typename Context>
+void UniformRandomSRKernel(const Context& dev_ctx,
+                           const ScalarArray& shape,
+                           DataType dtype,
+                           float min,
+                           float max,
+                           int seed,
+                           SelectedRows* out);
+
+}  // namespace phi
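The argument-mapping file that follows resolves the shape input with a fixed priority: a non-empty ShapeTensorList wins, then ShapeTensor (when the static 'shape' attribute is empty), then the 'shape' attribute itself. A minimal standalone sketch of that priority rule; the function and parameter names are hypothetical, not phi API:

    // Shape-source resolution order, as encoded in every branch of
    // UniformRandomOpArgumentMapping below.
    #include <cstdint>
    #include <iostream>
    #include <string>
    #include <vector>

    std::string PickShapeArg(bool has_shape_tensor_list, bool has_shape_tensor,
                             const std::vector<int64_t>& shape_attr) {
      if (has_shape_tensor_list) return "ShapeTensorList";
      if (has_shape_tensor && shape_attr.empty()) return "ShapeTensor";
      return "shape";
    }

    int main() {
      std::cout << PickShapeArg(false, true, {}) << '\n';      // ShapeTensor
      std::cout << PickShapeArg(false, true, {2, 3}) << '\n';  // shape
    }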
diff --git a/paddle/phi/ops/compat/uniform_random_sig.cc b/paddle/phi/ops/compat/uniform_random_sig.cc
new file mode 100644
index 0000000000..d06d4026f4
--- /dev/null
+++ b/paddle/phi/ops/compat/uniform_random_sig.cc
@@ -0,0 +1,159 @@
+/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/phi/core/compat/op_utils.h"
+
+namespace phi {
+
+KernelSignature UniformRandomOpArgumentMapping(
+    const ArgumentMappingContext& ctx) {
+  int diag_num = paddle::any_cast<int>(ctx.Attr("diag_num"));
+  if (ctx.IsDenseTensorOutput("Out")) {
+    if (diag_num) {
+      if (ctx.InputSize("ShapeTensorList") > 0) {
+        return KernelSignature("uniform_random_raw",
+                               {},
+                               {"ShapeTensorList",
+                                "dtype",
+                                "min",
+                                "max",
+                                "seed",
+                                "diag_num",
+                                "diag_step",
+                                "diag_val"},
+                               {"Out"});
+      } else {
+        const auto& shape =
+            paddle::any_cast<std::vector<int64_t>>(ctx.Attr("shape"));
+        if (ctx.HasInput("ShapeTensor") && shape.empty()) {
+          return KernelSignature("uniform_random_raw",
+                                 {},
+                                 {"ShapeTensor",
+                                  "dtype",
+                                  "min",
+                                  "max",
+                                  "seed",
+                                  "diag_num",
+                                  "diag_step",
+                                  "diag_val"},
+                                 {"Out"});
+        } else {
+          return KernelSignature("uniform_random_raw",
+                                 {},
+                                 {"shape",
+                                  "dtype",
+                                  "min",
+                                  "max",
+                                  "seed",
+                                  "diag_num",
+                                  "diag_step",
+                                  "diag_val"},
+                                 {"Out"});
+        }
+      }
+    } else {
+      if (ctx.InputSize("ShapeTensorList") > 0) {
+        return KernelSignature(
+            "uniform_random",
+            {},
+            {"ShapeTensorList", "dtype", "min", "max", "seed"},
+            {"Out"});
+      } else {
+        const auto& shape =
+            paddle::any_cast<std::vector<int64_t>>(ctx.Attr("shape"));
+        if (ctx.HasInput("ShapeTensor") && shape.empty()) {
+          return KernelSignature("uniform_random",
+                                 {},
+                                 {"ShapeTensor", "dtype", "min", "max", "seed"},
+                                 {"Out"});
+        } else {
+          return KernelSignature("uniform_random",
+                                 {},
+                                 {"shape", "dtype", "min", "max", "seed"},
+                                 {"Out"});
+        }
+      }
+    }
+  } else if (ctx.IsSelectedRowsOutput("Out")) {
+    if (diag_num) {
+      if (ctx.InputSize("ShapeTensorList") > 0) {
+        return KernelSignature("uniform_random_raw_sr",
+                               {},
+                               {"ShapeTensorList",
+                                "dtype",
+                                "min",
+                                "max",
+                                "seed",
+                                "diag_num",
+                                "diag_step",
+                                "diag_val"},
+                               {"Out"});
+      } else {
+        const auto& shape =
+            paddle::any_cast<std::vector<int64_t>>(ctx.Attr("shape"));
+        if (ctx.HasInput("ShapeTensor") && shape.empty()) {
+          return KernelSignature("uniform_random_raw_sr",
+                                 {},
+                                 {"ShapeTensor",
+                                  "dtype",
+                                  "min",
+                                  "max",
+                                  "seed",
+                                  "diag_num",
+                                  "diag_step",
+                                  "diag_val"},
+                                 {"Out"});
+        } else {
+          return KernelSignature("uniform_random_raw_sr",
+                                 {},
+                                 {"shape",
+                                  "dtype",
+                                  "min",
+                                  "max",
+                                  "seed",
+                                  "diag_num",
+                                  "diag_step",
+                                  "diag_val"},
+                                 {"Out"});
+        }
+      }
+    } else {
+      if (ctx.InputSize("ShapeTensorList") > 0) {
+        return KernelSignature(
+            "uniform_random_sr",
+            {},
+            {"ShapeTensorList", "dtype", "min", "max", "seed"},
+            {"Out"});
+      } else {
+        const auto& shape =
+            paddle::any_cast<std::vector<int64_t>>(ctx.Attr("shape"));
+        if (ctx.HasInput("ShapeTensor") && shape.empty()) {
+          return KernelSignature("uniform_random_sr",
+                                 {},
+                                 {"ShapeTensor", "dtype", "min", "max", "seed"},
+                                 {"Out"});
+        } else {
+          return KernelSignature("uniform_random_sr",
+                                 {},
+                                 {"shape", "dtype", "min", "max", "seed"},
+                                 {"Out"});
+        }
+      }
+    }
+  }
+  return KernelSignature("unregistered", {}, {}, {});
+}
+
+}  // namespace phi
+
+PD_REGISTER_ARG_MAPPING_FN(uniform_random, phi::UniformRandomOpArgumentMapping);
--
GitLab