From f9c9dc293bf2ba923e5d4df7f50d1624c164c2b0 Mon Sep 17 00:00:00 2001
From: lzy <569782149@qq.com>
Date: Mon, 29 May 2023 10:31:58 +0800
Subject: [PATCH] add top_p_sampling (#54127)

---
 paddle/phi/api/yaml/ops.yaml                  |   9 +
 paddle/phi/infermeta/binary.cc                |  20 +
 paddle/phi/infermeta/binary.h                 |   6 +
 .../phi/kernels/gpu/top_p_sampling_kernel.cu  | 702 ++++++++++++++++++
 paddle/phi/kernels/top_p_sampling_kernel.h    |  29 +
 python/paddle/__init__.py                     |   2 +
 .../tests/unittests/test_top_p_sampling.py    | 137 ++++
 python/paddle/tensor/__init__.py              |   2 +
 python/paddle/tensor/search.py                |  35 +
 9 files changed, 942 insertions(+)
 create mode 100644 paddle/phi/kernels/gpu/top_p_sampling_kernel.cu
 create mode 100644 paddle/phi/kernels/top_p_sampling_kernel.h
 create mode 100644 python/paddle/fluid/tests/unittests/test_top_p_sampling.py

diff --git a/paddle/phi/api/yaml/ops.yaml b/paddle/phi/api/yaml/ops.yaml
index 25bda07141c..9382083e320 100644
--- a/paddle/phi/api/yaml/ops.yaml
+++ b/paddle/phi/api/yaml/ops.yaml
@@ -1921,6 +1921,15 @@
     func : thresholded_relu
   backward : thresholded_relu_grad
 
+- op : top_p_sampling
+  args : (Tensor x, Tensor ps, int random_seed=-1)
+  output : Tensor(out), Tensor(ids)
+  infer_meta :
+    func : TopPSamplingInferMeta
+  kernel :
+    func : top_p_sampling
+    data_type : x
+
 - op : topk
   args : (Tensor x, Scalar(int) k = 1, int axis = -1, bool largest = true, bool sorted = true)
   output : Tensor(out), Tensor(indices)
diff --git a/paddle/phi/infermeta/binary.cc b/paddle/phi/infermeta/binary.cc
index 86b44c9a3c5..802de589b57 100644
--- a/paddle/phi/infermeta/binary.cc
+++ b/paddle/phi/infermeta/binary.cc
@@ -2742,6 +2742,26 @@ void TriangularSolveInferMeta(const MetaTensor& x,
   out->share_lod(y);
 }
 
+void TopPSamplingInferMeta(const MetaTensor& x,
+                           const MetaTensor& ps,
+                           int random_seed,
+                           MetaTensor* out,
+                           MetaTensor* ids) {
+  auto x_dims = x.dims();
+  auto ps_dims = ps.dims();
+  PADDLE_ENFORCE_EQ(x_dims[0],
+                    ps_dims[0],
+                    phi::errors::InvalidArgument(
+                        "The x_dims[0] must be equal to ps_dims[0], "
+                        "but received x_dims[0] = %d and ps_dims[0] = %d.",
+                        x_dims[0],
+                        ps_dims[0]));
+  ids->set_dims(phi::make_ddim({x_dims[0], 1}));
+  ids->set_dtype(DataType::INT64);
+  out->set_dims(phi::make_ddim({x_dims[0], 1}));
+  out->set_dtype(x.dtype());
+}
+
 void LstsqInferMeta(const MetaTensor& x,
                     const MetaTensor& y,
                     const Scalar& rcond,
diff --git a/paddle/phi/infermeta/binary.h b/paddle/phi/infermeta/binary.h
index ed4da703ce5..49bbee914fc 100644
--- a/paddle/phi/infermeta/binary.h
+++ b/paddle/phi/infermeta/binary.h
@@ -428,6 +428,12 @@ void TriangularSolveInferMeta(const MetaTensor& x,
                               bool unitriangular,
                               MetaTensor* out);
 
+void TopPSamplingInferMeta(const MetaTensor& x,
+                           const MetaTensor& ps,
+                           int random_seed,
+                           MetaTensor* out,
+                           MetaTensor* ids);
+
 void LstsqInferMeta(const MetaTensor& x,
                     const MetaTensor& y,
                     const Scalar& rcond,
diff --git a/paddle/phi/kernels/gpu/top_p_sampling_kernel.cu b/paddle/phi/kernels/gpu/top_p_sampling_kernel.cu
new file mode 100644
index 00000000000..3eb6b9e96ee
--- /dev/null
+++ b/paddle/phi/kernels/gpu/top_p_sampling_kernel.cu
@@ -0,0 +1,702 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/phi/kernels/top_p_sampling_kernel.h"
+
+#include <cuda_fp16.h>
+#include <curand_kernel.h>
+
+#include "cub/cub.cuh"
+#include "paddle/fluid/operators/kernel_primitives/functor_primitives.h"
+#include "paddle/phi/backends/gpu/gpu_context.h"
+#include "paddle/phi/backends/gpu/gpu_device_function.h"
+#include "paddle/phi/core/kernel_registry.h"
+#include "paddle/phi/core/tensor_utils.h"
+#include "paddle/phi/kernels/funcs/gather.cu.h"
+#include "paddle/phi/kernels/funcs/math_function.h"
+#include "paddle/phi/kernels/funcs/top_k_function_cuda.h"
+
+// #define DEBUG_TOPP
+
+namespace phi {
+
+template <typename T>
+struct DataTypeTraits {
+  using DataType = T;
+};
+
+template <>
+struct DataTypeTraits<phi::dtype::float16> {
+  using DataType = half;
+};
+
+// template <>
+// struct DataTypeTraits<phi::dtype::bfloat16> {
+//   using DataType = __nv_bfloat16;
+// };
+
+#define FINAL_MASK 0xFFFFFFFF
+
+#define FIXED_BLOCK_DIM_BASE(dim, ...) \
+  case (dim): {                        \
+    constexpr auto kBlockDim = (dim);  \
+    __VA_ARGS__;                       \
+  } break
+
+#define FIXED_BLOCK_DIM(...)                 \
+  FIXED_BLOCK_DIM_BASE(1024, ##__VA_ARGS__); \
+  FIXED_BLOCK_DIM_BASE(512, ##__VA_ARGS__);  \
+  FIXED_BLOCK_DIM_BASE(256, ##__VA_ARGS__);  \
+  FIXED_BLOCK_DIM_BASE(128, ##__VA_ARGS__);  \
+  FIXED_BLOCK_DIM_BASE(64, ##__VA_ARGS__);   \
+  FIXED_BLOCK_DIM_BASE(32, ##__VA_ARGS__)
+
+namespace ops = paddle::operators;
+
+struct SegmentOffsetIter {
+  explicit SegmentOffsetIter(int num_cols) : num_cols_(num_cols) {}
+
+  __host__ __device__ __forceinline__ int operator()(int idx) const {
+    return idx * num_cols_;
+  }
+
+  int num_cols_;
+};
+
+template <typename T>
+struct Pair {
+  __device__ __forceinline__ Pair() {}
+  __device__ __forceinline__ Pair(T value, int id) : v(value), id(id) {}
+
+  __device__ __forceinline__ void set(T value, int id) {
+    this->v = value;
+    this->id = id;  // a plain `id = id;` would be a self-assignment no-op
+  }
+
+  __device__ __forceinline__ void operator=(const Pair<T>& in) {
+    v = in.v;
+    id = in.id;
+  }
+
+  __device__ __forceinline__ bool operator<(const T value) const {
+    return (static_cast<float>(v) < static_cast<float>(value));
+  }
+
+  __device__ __forceinline__ bool operator>(const T value) const {
+    return (static_cast<float>(v) > static_cast<float>(value));
+  }
+  __device__ __forceinline__ bool operator<(const Pair<T>& in) const {
+    return (static_cast<float>(v) < static_cast<float>(in.v)) ||
+           ((static_cast<float>(v) == static_cast<float>(in.v)) &&
+            (id > in.id));
+  }
+
+  __device__ __forceinline__ bool operator>(const Pair<T>& in) const {
+    return (static_cast<float>(v) > static_cast<float>(in.v)) ||
+           ((static_cast<float>(v) == static_cast<float>(in.v)) &&
+            (id < in.id));
+  }
+
+  T v;
+  int id;
+};
+
+inline int div_up(int a, int n) { return (a + n - 1) / n; }
+
+__global__ void setup_kernel(curandState_t* state,
+                             const uint64_t seed,
+                             const int bs) {
+  int idx = blockIdx.x * blockDim.x + threadIdx.x;
+  for (int i = idx; i < bs; i += gridDim.x * blockDim.x) {
+    curand_init(seed + i, 0, 0, &state[i]);
+  }
+}
+
+template <typename T>
+__device__ __forceinline__ void AddTo(Pair<T> topk[],
+                                      const Pair<T>& p,
+                                      int beam_size) {
+  for (int k = beam_size - 2; k >= 0; k--) {
+    if (topk[k] < p) {
+      topk[k + 1] = topk[k];
+    } else {
+      topk[k + 1] = p;
+      return;
+    }
+  }
+  topk[0] = p;
+}
+
+template <typename T, int BlockSize>
+__device__ __forceinline__ void GetTopK(
+    Pair<T> topk[], const T* src, int idx, int dim, int beam_size) {
+  while (idx < dim) {
+    if (topk[beam_size - 1] < src[idx]) {
+      Pair<T> tmp(src[idx], idx);
+      AddTo<T>(topk, tmp, beam_size);
+    }
+    idx += BlockSize;
+  }
+}
+
+template <typename T, int BlockSize>
+__device__ __forceinline__ void GetTopK(Pair<T> topk[],
+                                        const T* src,
+                                        int idx,
+                                        int dim,
+                                        const Pair<T>& max,
+                                        int beam_size) {
+  while (idx < dim) {
+    if (topk[beam_size - 1] < src[idx]) {
+      Pair<T> tmp(src[idx], idx);
+      if (tmp < max) {
+        AddTo<T>(topk, tmp, beam_size);
+      }
+    }
+    idx += BlockSize;
+  }
+}
+
+template <typename T, int MaxLength, int BlockSize>
+__device__ __forceinline__ void ThreadGetTopK(Pair<T> topk[],
+                                              int* beam,
+                                              int beam_size,
+                                              const T* src,
+                                              bool* firstStep,
+                                              bool* is_empty,
+                                              Pair<T>* max,
+                                              int dim,
+                                              const int tid) {
+  if (*beam > 0) {
+    int length = (*beam) < beam_size ? *beam : beam_size;
+    if (*firstStep) {
+      *firstStep = false;
+      GetTopK<T, BlockSize>(topk, src, tid, dim, length);
+    } else {
+      for (int k = 0; k < MaxLength; k++) {
+        if (k < MaxLength - (*beam)) {
+          topk[k] = topk[k + *beam];
+        } else {
+          topk[k].set(std::numeric_limits<T>::min(), -1);
+        }
+      }
+      if (!(*is_empty)) {
+        GetTopK<T, BlockSize>(
+            topk + MaxLength - *beam, src, tid, dim, *max, length);
+      }
+    }
+
+    *max = topk[MaxLength - 1];
+    if ((*max).id == -1) *is_empty = true;
+    *beam = 0;
+  }
+}
+
+template <typename T>
+__forceinline__ __device__ Pair<T> WarpReduce(Pair<T> input) {
+#pragma unroll
+  for (int offset = 16; offset > 0; offset >>= 1) {
+    T tmp_val =
+        phi::backends::gpu::CudaShuffleDownSync(FINAL_MASK, input.v, offset);
+    int tmp_id =
+        phi::backends::gpu::CudaShuffleDownSync(FINAL_MASK, input.id, offset);
+    if (static_cast<float>(input.v) < static_cast<float>(tmp_val)) {
+      input.v = tmp_val;
+      input.id = tmp_id;
+    }
+  }
+  return input;
+}
+
+template <typename T, int MaxLength, int BlockSize>
+__device__ __forceinline__ void BlockReduce(Pair<T> shared_max[],
+                                            Pair<T> topk[],
+                                            Pair<T> beam_max[],
+                                            int* beam,
+                                            int* k,
+                                            int* count,
+                                            const int tid,
+                                            const int wid,
+                                            const int lane) {
+  while (true) {
+    __syncthreads();
+    Pair<T> input_now = topk[0];
+    input_now = WarpReduce(input_now);
+
+    if (lane == 0) {
+      shared_max[wid] = input_now;
+    }
+    __syncthreads();
+    input_now = (tid < BlockSize / 32)
+                    ? shared_max[lane]
+                    : Pair<T>(std::numeric_limits<T>::min(), -1);
+    if (wid == 0) {
+      input_now = WarpReduce(input_now);
+      if (lane == 0) shared_max[0] = input_now;
+    }
+    __syncthreads();
+    if (tid == 0) {
+      beam_max[*count] = shared_max[0];
+      (*count)++;
+    }
+    int tid_max = shared_max[0].id % BlockSize;
+    if (tid == tid_max) {
+      (*beam)++;
+    }
+    if (--(*k) == 0) break;
+    __syncthreads();
+
+    if (tid == tid_max) {
+      if (*beam < MaxLength) {
+        topk[0] = topk[*beam];
+      }
+    }
+
+    if (MaxLength < 5) {
+      if (*beam >= MaxLength) break;
+    } else {
+      unsigned mask = 0u;
+      mask = __ballot_sync(FINAL_MASK, true);
+      if (tid_max / 32 == wid) {
+        if (__shfl_down_sync(FINAL_MASK, *beam, tid_max % 32, 32) == MaxLength)
+          break;
+      }
+    }
+  }
+}
+
+template <typename T, int MaxLength, int TopPBeamTopK, int BlockSize>
+__global__ void KeMatrixTopPBeamTopK(const T* src,
+                                     T* top_ps,
+                                     int64_t* out_id,  // topk id
+                                     T* out_val,       // topk val
+                                     int vocab_size,
+                                     curandState_t* state,
+                                     int* count_iter,
+                                     int* count_iter_begin) {
+  const int tid = threadIdx.x;
+  const int wid = tid / 32;
+  const int lane = tid % 32;
+  const int bid = blockIdx.x;
+
+  int top_num = TopPBeamTopK;
+  float top_p_num = static_cast<float>(top_ps[bid]);
+
+  __shared__ Pair<T> shared_max[BlockSize / 32];
+  __shared__ Pair<T> beam_max[TopPBeamTopK];
+
+  Pair<T> topk[MaxLength];
+  int beam = MaxLength;
+  Pair<T> max;
+  bool is_empty = false;
+  bool firststep = true;
+  __shared__ int count;
+
+  if (tid == 0) {
+    count = 0;
+  }
+
+  for (int j = 0; j < MaxLength; j++) {
+    topk[j].set(std::numeric_limits<T>::min(), -1);
+  }
+
+  while (top_num) {
+    ThreadGetTopK<T, MaxLength, BlockSize>(topk,
+                                           &beam,
+                                           TopPBeamTopK,
+                                           src + bid * vocab_size,
+                                           &firststep,
+                                           &is_empty,
+                                           &max,
+                                           vocab_size,
+                                           tid);
+    BlockReduce<T, MaxLength, BlockSize>(
+        shared_max, topk, beam_max, &beam, &top_num, &count, tid, wid, lane);
+  }
+  if (tid == 0) {
+    count_iter_begin[bid] = count_iter[bid];
+    float rand_top_p = curand_uniform(state + bid) * top_p_num;
+    top_ps[bid] = (T)rand_top_p;
+    float sum_prob = 0.0f;
+    for (int i = 0; i < TopPBeamTopK; i++) {
+      sum_prob += static_cast<float>(beam_max[i].v);
+#ifdef DEBUG_TOPP
+      printf("bi: %d, top_p: %f, rand_top_p: %f, sum_prob: %f\n",
+             bid,
+             top_p_num,
+             rand_top_p,
+             sum_prob);
+#endif
+      if (sum_prob >= rand_top_p) {
+        count_iter_begin[bid] += 1;
+        out_id[bid] = (int64_t)beam_max[i].id;
+        out_val[bid] = beam_max[i].v;
+#ifdef DEBUG_TOPP
+        printf(
+            "bi: %d, early stop id: %d\n", bid, static_cast<int>(out_id[bid]));
+#endif
+        break;
+      }
+    }
+  }
+}
+
+__global__ void SetCountIter(int* count_iter, int num) {
+  int tid = threadIdx.x;
+  int bid = blockIdx.x;
+  int idx = bid * blockDim.x + tid;
+  for (int i = idx; i < num; i += gridDim.x * blockDim.x) {
+    count_iter[i] = i;
+  }
+}
+
+template <typename T>
+__global__ void FillIndex(T* indices, T num_rows, T num_cols) {
+  int col_id = threadIdx.x;
+  int row_id = blockIdx.x;
+
+  for (T j = row_id; j < num_rows; j += gridDim.x) {
+    for (T i = col_id; i < num_cols; i += blockDim.x) {
+      indices[j * num_cols + i] = i;
+    }
+  }
+}
+
+struct BlockPrefixCallbackOp {
+  float running_total;
+
+  __device__ BlockPrefixCallbackOp(float running_total)
+      : running_total(running_total) {}
+
+  __device__ float operator()(float block_aggregate) {
+    float old_prefix = running_total;
+    running_total += block_aggregate;
+    return old_prefix;
+  }
+};
+
+template <typename T, int BLOCK_SIZE>
+__global__ void topp_sampling(T* sorted_probs,
+                              int64_t* sorted_id,
+                              T* out_val,
+                              int64_t* out_id,
+                              const T* top_ps,
+                              int p_num,
+                              int vocab_size,
+                              int* count_iter,
+                              int* count_iter_begin) {
+  __shared__ int stop_shared;
+  __shared__ float rand_p;
+  const int tid = threadIdx.x;
+  const int bid = blockIdx.x;
+  constexpr int NUM_WARPS = BLOCK_SIZE / 32;
+  const int lane_id = tid % 32;
+  const int warp_id = tid / 32;
+  const float p_t = static_cast<float>(top_ps[bid]);
+  if (tid == 0) {
+    stop_shared = 0;
+    rand_p = p_t;
+#ifdef DEBUG_TOPP
+    printf("bi: %d, p: %f\n", bid, rand_p);
+#endif
+  }
+  if (count_iter_begin[bid] == count_iter[bid + 1]) {
+    // this row was already resolved by the top-k early-stop pass
+    return;
+  }
+
+  typedef cub::BlockScan<float, BLOCK_SIZE> BlockScan;
+  __shared__ typename BlockScan::TempStorage temp_storage;
+  __shared__ uint32_t selected_shared[NUM_WARPS];
+
+  // Initialize running total
+  BlockPrefixCallbackOp prefix_op(0);
+
+  if (lane_id == 0) {
+    selected_shared[warp_id] = 0;
+  }
+  __syncthreads();
+
+  int offset = bid * vocab_size;
+#ifdef DEBUG_TOPP
+  if (tid == 0) {
+    printf(
+        "first_elem1_1: %f, first_elem1_2: %f, first_id1_1: %d, first_id1_2: "
+        "%d\n",
+        static_cast<float>(sorted_probs[offset]),
+        static_cast<float>(sorted_probs[offset + 1]),
+        static_cast<int>(sorted_id[offset]),
+        static_cast<int>(sorted_id[offset + 1]));
+  }
+#endif
+  int end = ((vocab_size + BLOCK_SIZE - 1) / BLOCK_SIZE) * BLOCK_SIZE;
+  int i_activate = 0;
+  float thread_offset = 0;
+  for (int i = tid; i < end; i += BLOCK_SIZE) {
+    float thread_count =
+        (i < vocab_size) ? static_cast<float>(sorted_probs[offset + i]) : 0.f;
+    BlockScan(temp_storage)
+        .InclusiveSum(thread_count, thread_offset, prefix_op);
+
+    uint32_t activate_mask =
+        __ballot_sync(FINAL_MASK, rand_p <= thread_offset);
+
+    i_activate = i;
+    if (activate_mask != 0) {
+      if (lane_id == 0) {
+        atomicAdd(&stop_shared, 1);
+        selected_shared[warp_id] = activate_mask;
+      }
+    }
+    __syncthreads();
+    if (stop_shared > 0) {
+      break;
+    }
+  }
+  __syncthreads();
+  if (stop_shared == 0) {
+    if (tid == 0) {
+      out_id[bid] = sorted_id[offset + vocab_size - 1];
+      out_val[bid] = sorted_probs[offset + vocab_size - 1];
+#ifdef DEBUG_TOPP
+      printf("stop_shared: %d, out_id: %d, out_val: %f\n",
+             static_cast<int>(stop_shared),
+             static_cast<int>(out_id[bid]),
+             static_cast<float>(out_val[bid]));
+#endif
+    }
+    return;
+  }
+
+#ifdef DEBUG_TOPP
+  if (tid == 0) {
+    printf(
+        "first_elem2_1: %f, first_elem2_2: %f, first_id2_1: %d, first_id2_2: "
+        "%d\n",
+        static_cast<float>(sorted_probs[offset]),
+        static_cast<float>(sorted_probs[offset + 1]),
+        static_cast<int>(sorted_id[offset]),
+        static_cast<int>(sorted_id[offset + 1]));
+  }
+#endif
+  bool skip = (selected_shared[warp_id] == 0);
+  for (int i = 0; i < warp_id; i++) {
+    if (selected_shared[i] != 0) {
+      skip = true;
+    }
+  }
+  if (!skip) {
+    int active_lane_id =
+        WARP_SIZE - __popc(selected_shared[warp_id]);  // first not 0
+    if (lane_id == active_lane_id) {
+#ifdef DEBUG_TOPP
+      printf(
+          "active_lane_id: %d, i_activate: %d.\n", active_lane_id, i_activate);
+      for (int i = 0; i < active_lane_id; i++) {
+        printf("p %d, value: %f\n",
+               i,
+               static_cast<float>(sorted_probs[offset + i]));
+      }
+#endif
+      out_id[bid] = sorted_id[offset + i_activate];
+      out_val[bid] = sorted_probs[offset + i_activate];
+    }
+  }
+}
+
+int GetBlockSize(int vocab_size) {
+  if (vocab_size > 512) {
+    return 1024;
+  } else if (vocab_size > 256) {
+    return 512;
+  } else if (vocab_size > 128) {
+    return 256;
+  } else if (vocab_size > 64) {
+    return 128;
+  } else {
+    return 64;
+  }
+}
+
+__global__ void set_sorted_num(int* need_sorted_num, int bs) {
+  *need_sorted_num = bs;
+}
+
+template <typename T>
+__global__ void print_kernel(T* input, int size) {
+  printf("[");
+  for (int i = 0; i < size; i++) {
+    if (i != size - 1) {
+      printf("%f, ", static_cast<float>(input[i]));
+    } else {
+      printf("%f]\n", static_cast<float>(input[i]));
+    }
+  }
+}
+
+template <typename T, typename Context>
+void TopPSamplingKernel(const Context& dev_ctx,
+                        const DenseTensor& x,
+                        const DenseTensor& ps,
+                        int random_seed,
+                        DenseTensor* out,
+                        DenseTensor* ids) {
+  typedef DataTypeTraits<T> traits_;
+  typedef typename traits_::DataType DataType_;
+  auto cu_stream = dev_ctx.stream();
+  const auto* input = &x;
+  // get the input dims
+  const auto& in_dims = input->dims();
+  int p_num = ps.numel();
+  int bs = in_dims[0];
+  int vocab_size = in_dims[1];
+  T* out_ptr = dev_ctx.template Alloc<T>(out);
+  int64_t* ids_ptr = dev_ctx.template Alloc<int64_t>(ids);
+
+  DenseTensor ps_now;
+  ps_now.Resize(phi::make_ddim({bs, 1}));
+  dev_ctx.template Alloc<T>(&ps_now);
+  phi::Copy(dev_ctx, ps, dev_ctx.GetPlace(), false, &ps_now);
+
+  DenseTensor inds_input;
+  inds_input.Resize(phi::make_ddim({bs, vocab_size}));
+  dev_ctx.template Alloc<int64_t>(&inds_input);
+
+  DenseTensor sorted_out;
+  sorted_out.Resize(phi::make_ddim({bs, vocab_size}));
+  dev_ctx.template Alloc<T>(&sorted_out);
+
+  DenseTensor sorted_id;
+  sorted_id.Resize(phi::make_ddim({bs, vocab_size}));
+  dev_ctx.template Alloc<int64_t>(&sorted_id);
+
+  int BlockSize = GetBlockSize(vocab_size);
+
+  switch (BlockSize) {
+    FIXED_BLOCK_DIM(FillIndex<<<bs, kBlockDim, 0, cu_stream>>>(
+        inds_input.data<int64_t>(), bs, vocab_size));
+    default:
+      PD_THROW("the input data shape has error in the FillIndex kernel.");
+  }
+
+  curandState_t* dev_curand_states;
+  phi::Allocator::AllocationPtr curand_states_buf{nullptr};
+  curand_states_buf = phi::memory_utils::Alloc(
+      dev_ctx.GetPlace(),
+      bs * sizeof(curandState_t),
+      phi::Stream(reinterpret_cast<phi::StreamId>(dev_ctx.stream())));
+  dev_curand_states =
+      reinterpret_cast<curandState_t*>(curand_states_buf->ptr());
+  if (random_seed == -1) {
+    srand((unsigned int)(time(NULL)));
+    setup_kernel<<<1, 256, 0, cu_stream>>>(dev_curand_states, rand(), bs);
+  } else {
+    setup_kernel<<<1, 256, 0, cu_stream>>>(dev_curand_states, random_seed, bs);
+  }
+
+  DenseTensor count_iter;
+  count_iter.Resize(phi::make_ddim({bs + 1}));
+  dev_ctx.template Alloc<int>(&count_iter);
+  DenseTensor count_iter_begin;
+  count_iter_begin.Resize(phi::make_ddim({bs}));
+  dev_ctx.template Alloc<int>(&count_iter_begin);
+  SetCountIter<<<1, 256, 0, cu_stream>>>(count_iter.data<int>(), bs + 1);
+
+  constexpr int TopKMaxLength = 2;
+  constexpr int TopPBeamTopK = 10;
+  switch (BlockSize) {
+    FIXED_BLOCK_DIM(
+        KeMatrixTopPBeamTopK<T, TopKMaxLength, TopPBeamTopK, kBlockDim>
+        <<<bs, kBlockDim, 0, cu_stream>>>(x.data<T>(),
+                                          ps_now.data<T>(),
+                                          ids_ptr,
+                                          out_ptr,
+                                          vocab_size,
+                                          dev_curand_states,
+                                          count_iter.data<int>(),
+                                          count_iter_begin.data<int>()));
+    default:
+      PD_THROW("the input data shape has error in the topp_beam_topk kernel.");
+  }
+
+  size_t temp_storage_bytes = 0;
+
+  cub::TransformInputIterator<int, SegmentOffsetIter, int*>
+      segment_offsets_t_begin(count_iter_begin.data<int>(),
+                              SegmentOffsetIter(vocab_size));
+
+  cub::TransformInputIterator<int, SegmentOffsetIter, int*>
+      segment_offsets_t_end(count_iter.data<int>(),
+                            SegmentOffsetIter(vocab_size));
+
+  cub::DeviceSegmentedRadixSort::SortPairsDescending(
+      nullptr,
+      temp_storage_bytes,
+      reinterpret_cast<DataType_*>(const_cast<T*>(x.data<T>())),
+      reinterpret_cast<DataType_*>(const_cast<T*>(sorted_out.data<T>())),
+      inds_input.data<int64_t>(),
+      sorted_id.data<int64_t>(),
+      vocab_size * bs,
+      bs,
+      segment_offsets_t_begin,
+      segment_offsets_t_end + 1,
+      0,
+      sizeof(T) * 8,
+      cu_stream);
+
+  temp_storage_bytes = div_up(temp_storage_bytes, 256) * 256;
+  int64_t temp_size = temp_storage_bytes;
+  DenseTensor temp_storage;
+  temp_storage.Resize(phi::make_ddim({temp_size}));
+  dev_ctx.template Alloc<uint8_t>(&temp_storage);
+
+  cub::DeviceSegmentedRadixSort::SortPairsDescending(
+      temp_storage.data<uint8_t>(),
+      temp_storage_bytes,
+      reinterpret_cast<DataType_*>(const_cast<T*>(x.data<T>())),
+      reinterpret_cast<DataType_*>(const_cast<T*>(sorted_out.data<T>())),
+      inds_input.data<int64_t>(),
+      sorted_id.data<int64_t>(),
+      vocab_size * bs,
+      bs,
+      segment_offsets_t_begin,
+      segment_offsets_t_end + 1,
+      0,
+      sizeof(T) * 8,
+      cu_stream);
+  switch (BlockSize) {
+    FIXED_BLOCK_DIM(
+        topp_sampling<T, kBlockDim>
+        <<<bs, kBlockDim, 0, cu_stream>>>(sorted_out.data<T>(),
+                                          sorted_id.data<int64_t>(),
+                                          out_ptr,
+                                          ids_ptr,
+                                          ps_now.data<T>(),
+                                          p_num,
+                                          vocab_size,
+                                          count_iter.data<int>(),
+                                          count_iter_begin.data<int>()));
+    default:
+      PD_THROW("the input data shape has error in the topp_sampling kernel.");
+  }
+  return;
+}
+
+}  // namespace phi
+
+PD_REGISTER_KERNEL(top_p_sampling,
+                   GPU,
+                   ALL_LAYOUT,
+                   phi::TopPSamplingKernel,
+                   float,
+                   phi::dtype::float16) {}
diff --git a/paddle/phi/kernels/top_p_sampling_kernel.h b/paddle/phi/kernels/top_p_sampling_kernel.h
new file mode 100644
index 00000000000..e5a2bff8c31
--- /dev/null
+++ b/paddle/phi/kernels/top_p_sampling_kernel.h
@@ -0,0 +1,29 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include "paddle/phi/core/dense_tensor.h"
+
+namespace phi {
+
+template <typename T, typename Context>
+void TopPSamplingKernel(const Context& dev_ctx,
+                        const DenseTensor& x,
+                        const DenseTensor& ps,
+                        int random_seed,
+                        DenseTensor* out,
+                        DenseTensor* ids);
+
+}  // namespace phi
diff --git a/python/paddle/__init__.py b/python/paddle/__init__.py
index ca237df8e53..76b2434ef50 100644
--- a/python/paddle/__init__.py
+++ b/python/paddle/__init__.py
@@ -320,6 +320,7 @@ from .tensor.search import nonzero  # noqa: F401
 from .tensor.search import sort  # noqa: F401
 from .tensor.search import kthvalue  # noqa: F401
 from .tensor.search import mode  # noqa: F401
+from .tensor.search import top_p_sampling  # noqa: F401
 
 from .tensor.to_string import set_printoptions  # noqa: F401
 
@@ -542,6 +543,7 @@ __all__ = [  # noqa
     'zeros_like',
     'maximum',
     'topk',
+    'top_p_sampling',
     'index_select',
     'CPUPlace',
     'matmul',
diff --git a/python/paddle/fluid/tests/unittests/test_top_p_sampling.py b/python/paddle/fluid/tests/unittests/test_top_p_sampling.py
new file mode 100644
index 00000000000..4a8544250ff
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/test_top_p_sampling.py
@@ -0,0 +1,137 @@
+# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+import numpy as np
+
+import paddle
+from paddle.fluid import core
+
+
+def TopPProcess(probs, top_p):
+    sorted_probs = paddle.sort(probs, descending=True)
+    sorted_indices = paddle.argsort(probs, descending=True)
+    cumulative_probs = paddle.cumsum(sorted_probs, axis=-1)
+
+    # Remove tokens with cumulative probs above top_p, but keep at
+    # least one token per row
+    sorted_indices_to_remove = cumulative_probs > top_p
+
+    # Keep the first token
+    sorted_indices_to_remove = paddle.cast(
+        sorted_indices_to_remove, dtype='int64'
+    )
+    sorted_indices_to_remove[:, 1:] = sorted_indices_to_remove[:, :-1].clone()
+    sorted_indices_to_remove[:, 0] = 0
+
+    # Scatter sorted tensors to original indexing
+    sorted_indices = (
+        sorted_indices
+        + paddle.arange(probs.shape[0]).unsqueeze(-1) * probs.shape[-1]
+    )
+    condition = paddle.scatter(
+        sorted_indices_to_remove.flatten(),
+        sorted_indices.flatten(),
+        sorted_indices_to_remove.flatten(),
+    )
+    condition = paddle.cast(condition, 'bool').reshape(probs.shape)
+    probs = paddle.where(condition, paddle.full_like(probs, 0.0), probs)
+    next_tokens = paddle.multinomial(probs)
+    next_scores = paddle.index_sample(probs, next_tokens)
+    return next_scores, next_tokens
+
+
+class TestTopPAPI(unittest.TestCase):
+    def setUp(self):
+        self.topp = 0.0
+        self.seed = 6688
+        self.batch_size = 3
+        self.vocab_size = 10000
+        self.dtype = "float32"
+        self.input_data = np.random.rand(self.batch_size, self.vocab_size)
+
+    def run_dygraph(self, place):
+        with paddle.fluid.dygraph.guard(place):
+            input_tensor = paddle.to_tensor(self.input_data, self.dtype)
+            topp_tensor = paddle.to_tensor(
+                [
+                    self.topp,
+                ]
+                * self.batch_size,
+                self.dtype,
+            ).reshape((-1, 1))
+            # basic test case
+            paddle_result = paddle.top_p_sampling(
+                input_tensor, topp_tensor, self.seed
+            )
+            ref_res = TopPProcess(input_tensor, self.topp)
+
+            np.testing.assert_allclose(
+                paddle_result[0].numpy(), ref_res[0].numpy(), rtol=1e-05
+            )
+            np.testing.assert_allclose(
+                paddle_result[1].numpy().flatten(),
+                ref_res[1].numpy().flatten(),
+                rtol=0,
+            )
+
+    def run_static(self, place):
+        paddle.enable_static()
+        with paddle.static.program_guard(
+            paddle.static.Program(), paddle.static.Program()
+        ):
+            input_tensor = paddle.static.data(
+                name="x", shape=[6, 1030], dtype=self.dtype
+            )
+            topp_tensor = paddle.static.data(
+                name="topp", shape=[6, 1], dtype=self.dtype
+            )
+            result = paddle.top_p_sampling(input_tensor, topp_tensor, self.seed)
+            ref_res = TopPProcess(input_tensor, self.topp)
+            exe = paddle.static.Executor(place)
+            input_data = np.random.rand(6, 1030).astype(self.dtype)
+            paddle_result = exe.run(
+                feed={
+                    "x": input_data,
+                    # reshape to match the declared [6, 1] shape of "topp"
+                    "topp": np.array([self.topp] * 6)
+                    .astype(self.dtype)
+                    .reshape((-1, 1)),
+                },
+                fetch_list=[
+                    result[0],
+                    result[1],
+                    ref_res[0],
+                    ref_res[1],
+                ],
+            )
+            np.testing.assert_allclose(
+                paddle_result[0], paddle_result[2], rtol=1e-05
+            )
+            np.testing.assert_allclose(
+                paddle_result[1], paddle_result[3], rtol=1e-05
+            )
+
+    def test_cases(self):
+        places = [core.CUDAPlace(0)]
+        for place in places:
+            self.run_dygraph(place)
+            self.run_static(place)
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/python/paddle/tensor/__init__.py b/python/paddle/tensor/__init__.py
index b78ac0e57c2..a7087ae544b 100644
--- a/python/paddle/tensor/__init__.py
+++ b/python/paddle/tensor/__init__.py
@@ -278,6 +278,7 @@ from .search import index_sample  # noqa: F401
 from .search import masked_select  # noqa: F401
 from .search import kthvalue  # noqa: F401
 from .search import mode  # noqa: F401
+from .search import top_p_sampling  # noqa: F401
 
 from .stat import mean  # noqa: F401
 from .stat import std  # noqa: F401
@@ -468,6 +469,7 @@ tensor_method_func = [  # noqa
     'argsort',
     'masked_select',
     'topk',
+    'top_p_sampling',
     'where',
     'index_select',
     'nonzero',
diff --git a/python/paddle/tensor/search.py b/python/paddle/tensor/search.py
index 881d95f8d96..20c834af585 100755
--- a/python/paddle/tensor/search.py
+++ b/python/paddle/tensor/search.py
@@ -1129,3 +1129,38 @@ def kthvalue(x, k, axis=None, keepdim=False, name=None):
         )
         indices.stop_gradient = True
     return values, indices
+
+
+def top_p_sampling(x, ps, seed=None, name=None):
+    """
+    Get the top-p sampled scores and ids for each row of ``x``.
+
+    Args:
+        x(Tensor): A 2-D Tensor with type float32 or float16, holding the
+            probability of each token in the vocabulary.
+        ps(Tensor): A Tensor with the same dtype as ``x`` and shape
+            [batch_size, 1], holding the top-p threshold of each row.
+        seed(int, optional): The random seed. Default: None, in which case a
+            random seed is generated at run time.
+        name (str, optional): For details, please refer to :ref:`api_guide_Name`. Generally, no setting is required. Default: None.
+
+    Returns:
+        tuple(Tensor), return the sampled scores and indices. The scores' data type is the same as the input ``x``. The indices' data type is int64.
+    """
+
+    if seed is None:
+        seed = -1
+
+    if in_dygraph_mode():
+        return _C_ops.top_p_sampling(x, ps, seed)
+
+    inputs = {"x": [x], "ps": [ps]}
+    # the op attribute is named `random_seed` in ops.yaml
+    attrs = {"random_seed": seed}
+
+    helper = LayerHelper('top_p_sampling', **locals())
+    out = helper.create_variable_for_type_inference(dtype=x.dtype)
+    ids = helper.create_variable_for_type_inference(dtype="int64")
+    helper.append_op(
+        type='top_p_sampling',
+        inputs=inputs,
+        outputs={'out': [out], 'ids': [ids]},
+        attrs=attrs,
+    )
+    return out, ids
--
GitLab
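
A minimal usage sketch of the new API (not part of the patch; the tensor
values and shapes below are illustrative, and a CUDA build of Paddle is
required, since the kernel is registered for GPU only):

    import paddle

    # probabilities for a batch of 2 rows over a vocabulary of 5 tokens
    probs = paddle.to_tensor(
        [[0.1, 0.4, 0.2, 0.2, 0.1], [0.3, 0.3, 0.1, 0.2, 0.1]],
        dtype="float32",
    )
    # one top-p threshold per row, shape [batch_size, 1]
    ps = paddle.to_tensor([[0.9], [0.7]], dtype="float32")
    scores, ids = paddle.top_p_sampling(probs, ps, seed=6688)
    print(scores.shape, ids.shape)  # [2, 1] [2, 1]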