diff --git a/paddle/phi/api/yaml/ops.yaml b/paddle/phi/api/yaml/ops.yaml
index 9382083e320decd33d91ae2d2581c23ec5aa9e86..fc5091c00da8a32656ceeb5cc85a12911faec520 100644
--- a/paddle/phi/api/yaml/ops.yaml
+++ b/paddle/phi/api/yaml/ops.yaml
@@ -1922,13 +1922,14 @@
   backward : thresholded_relu_grad
 
 - op : top_p_sampling
-  args : (Tensor x, Tensor ps, int random_seed=-1)
+  args : (Tensor x, Tensor ps, Tensor threshold, int random_seed=-1)
   output : Tensor (out), Tensor(ids)
   infer_meta :
     func : TopPSamplingInferMeta
   kernel :
     func : top_p_sampling
     data_type : x
+  optional : threshold
 
 - op : topk
   args : (Tensor x, Scalar(int) k = 1, int axis = -1, bool largest = true, bool sorted = true)
diff --git a/paddle/phi/infermeta/binary.cc b/paddle/phi/infermeta/binary.cc
index 802de589b579e3f441a048dddca05858f194b897..383d932975464acc7a234f96dc4aa128c64d1fdd 100644
--- a/paddle/phi/infermeta/binary.cc
+++ b/paddle/phi/infermeta/binary.cc
@@ -2744,6 +2744,7 @@ void TriangularSolveInferMeta(const MetaTensor& x,
 
 void TopPSamplingInferMeta(const MetaTensor& x,
                            const MetaTensor& ps,
+                           const MetaTensor& threshold,
                            int random_seed,
                            MetaTensor* out,
                            MetaTensor* ids) {
diff --git a/paddle/phi/infermeta/binary.h b/paddle/phi/infermeta/binary.h
index 49bbee914fc368394cc3b25b79122246f1d1b791..21e6d159ae1c8f8e35ce2d4fd46799a2ab5fd250 100644
--- a/paddle/phi/infermeta/binary.h
+++ b/paddle/phi/infermeta/binary.h
@@ -430,6 +430,7 @@ void TriangularSolveInferMeta(const MetaTensor& x,
 
 void TopPSamplingInferMeta(const MetaTensor& x,
                            const MetaTensor& ps,
+                           const MetaTensor& threshold,
                            int random_seed,
                            MetaTensor* out,
                            MetaTensor* ids);
diff --git a/paddle/phi/kernels/gpu/top_p_sampling_kernel.cu b/paddle/phi/kernels/gpu/top_p_sampling_kernel.cu
index 3eb6b9e96eec8ca4b3c9afe41dfebd9d9d4ab653..1414b3c7194a144c0bc0fe89392469c6879b1ba5 100644
--- a/paddle/phi/kernels/gpu/top_p_sampling_kernel.cu
+++ b/paddle/phi/kernels/gpu/top_p_sampling_kernel.cu
@@ -41,11 +41,6 @@ struct DataTypeTraits<phi::dtype::float16> {
   using DataType = half;
 };
 
-// template <>
-// struct DataTypeTraits<phi::dtype::bfloat16> {
-//   using DataType = __nv_bfloat16;
-// };
-
 #define FINAL_MASK 0xFFFFFFFF
 
 #define FIXED_BLOCK_DIM_BASE(dim, ...) \
@@ -119,7 +114,7 @@ __global__ void setup_kernel(curandState_t* state,
                              const int bs) {
   int idx = blockIdx.x * blockDim.x + threadIdx.x;
   for (int i = idx; i < bs; i += gridDim.x * blockDim.x) {
-    curand_init(seed + i, 0, 0, &state[i]);
+    curand_init(seed, i, 0, &state[i]);
   }
 }
 
@@ -278,6 +273,7 @@ __device__ __forceinline__ void BlockReduce(Pair<T> shared_max[],
 
 template <typename T, int MaxLength, int TopPBeamTopK, int BlockSize>
 __global__ void KeMatrixTopPBeamTopK(const T* src,
+                                     const T* threshold,
                                      T* top_ps,
                                      int64_t* out_id,  // topk id
                                      T* out_val,       // topk val
@@ -289,6 +285,8 @@ __global__ void KeMatrixTopPBeamTopK(const T* src,
   const int wid = tid / 32;
   const int lane = tid % 32;
   const int bid = blockIdx.x;
+  const float threshold_now =
+      threshold ? static_cast<float>(threshold[bid]) : 0.f;
   int top_num = TopPBeamTopK;
   float top_p_num = static_cast<float>(top_ps[bid]);
 
@@ -329,8 +327,10 @@ __global__ void KeMatrixTopPBeamTopK(const T* src,
     float rand_top_p = curand_uniform(state + bid) * top_p_num;
     top_ps[bid] = (T)rand_top_p;
     float sum_prob = 0.0f;
+
     for (int i = 0; i < TopPBeamTopK; i++) {
-      sum_prob += static_cast<float>(beam_max[i].v);
+      float val = static_cast<float>(beam_max[i].v);
+      sum_prob += val;
 #ifdef DEBUG_TOPP
       printf("bi: %d, top_p: %f, rand_top_p: %f, sum_prob: %f\n",
              bid,
@@ -340,12 +340,21 @@
 #endif
       if (sum_prob >= rand_top_p) {
         count_iter_begin[bid] += 1;
-        out_id[bid] = (int64_t)beam_max[i].id;
-        out_val[bid] = beam_max[i].v;
-#ifdef DEBUG_TOPP
-        printf(
-            "bi: %d, early stop id: %d\n", bid, static_cast<int>(out_id[bid]));
-#endif
+        if (val < threshold_now) {
+          // don't sample low score token
+          int start_id = i == 0 ? 0 : i - 1;
+          for (int j = start_id; j >= 0; j--) {
+            float val_now = static_cast<float>(beam_max[j].v);
+            if (val_now >= threshold_now || j == 0) {
+              out_id[bid] = static_cast<int64_t>(beam_max[j].id);
+              out_val[bid] = beam_max[j].v;
+              break;
+            }
+          }
+        } else {
+          out_id[bid] = static_cast<int64_t>(beam_max[i].id);
+          out_val[bid] = beam_max[i].v;
+        }
         break;
       }
     }
@@ -374,11 +383,14 @@ __global__ void FillIndex(T* indices, T num_rows, T num_cols) {
 }
 
 struct BlockPrefixCallbackOp {
+  // Running prefix
   float running_total;
-
+  // Constructor
   __device__ BlockPrefixCallbackOp(float running_total)
       : running_total(running_total) {}
-
+  // Callback operator to be entered by the first warp of threads in the block.
+  // Thread-0 is responsible for returning a value for seeding the block-wide
+  // scan.
   __device__ float operator()(float block_aggregate) {
     float old_prefix = running_total;
     running_total += block_aggregate;
@@ -386,14 +398,28 @@
   }
 };
 
+template <typename T>
+__device__ T max_func(const T a, const T b) {
+  return a > b ? a : b;
+}
+
+template <typename T>
+struct MaxOp {
+  __device__ __forceinline__ T operator()(const T& a, const T& b) const {
+    return max_func(a, b);
+  }
+};
+
 template <typename T, int BLOCK_SIZE>
 __global__ void topp_sampling(T* sorted_probs,
                               int64_t* sorted_id,
                               T* out_val,
                               int64_t* out_id,
                               const T* top_ps,
-                              int p_num,
-                              int vocab_size,
+                              const T* threshold,
+                              const uint64_t seed,
+                              const int p_num,
+                              const int vocab_size,
                               int* count_iter,
                               int* count_iter_begin) {
   __shared__ int stop_shared;
@@ -404,6 +430,8 @@ __global__ void topp_sampling(T* sorted_probs,
   const int lane_id = tid % 32;
   const int warp_id = tid / 32;
   const float p_t = static_cast<float>(top_ps[bid]);
+  const float threshold_now =
+      threshold ? static_cast<float>(threshold[bid]) : 0.f;
   if (tid == 0) {
     stop_shared = 0;
     rand_p = p_t;
@@ -417,8 +445,11 @@
   }
 
   typedef cub::BlockScan<float, BLOCK_SIZE> BlockScan;
+  typedef cub::BlockReduce<int, BLOCK_SIZE> BlockReduce;
   __shared__ typename BlockScan::TempStorage temp_storage;
+  __shared__ typename BlockReduce::TempStorage temp_storage_reduce;
   __shared__ uint32_t selected_shared[NUM_WARPS];
+  int threshold_id = 0;
 
   // Initialize running total
   BlockPrefixCallbackOp prefix_op(0);
@@ -429,23 +460,15 @@
   __syncthreads();
 
   int offset = bid * vocab_size;
-#ifdef DEBUG_TOPP
-  if (tid == 0) {
-    printf(
-        "first_elem1_1: %f, first_elem1_2: %f, first_id1_1: %d, first_id1_2: "
-        "%d\n",
-        static_cast<float>(sorted_probs[offset]),
-        static_cast<float>(sorted_probs[offset + 1]),
-        static_cast<int>(sorted_id[offset]),
-        static_cast<int>(sorted_id[offset + 1]);
-  }
-#endif
   int end = ((vocab_size + BLOCK_SIZE - 1) / BLOCK_SIZE) * BLOCK_SIZE;
   int i_activate = 0;
   float thread_offset = 0;
   for (int i = tid; i < end; i += BLOCK_SIZE) {
     float thread_count =
         (i < vocab_size) ? static_cast<float>(sorted_probs[offset + i]) : 0.f;
+    if (i < vocab_size && thread_count >= threshold_now) {
+      threshold_id = i;
+    }
     BlockScan(temp_storage)
         .InclusiveSum(thread_count, thread_offset, prefix_op);
 
@@ -466,32 +489,15 @@
   __syncthreads();
   if (stop_shared == 0) {
     if (tid == 0) {
-      out_id[bid] = sorted_id[offset + vocab_size - 1];
-      out_val[bid] = sorted_probs[offset + vocab_size - 1];
-#ifdef DEBUG_TOPP
-      printf("stop_shared: %d, out_id: %d, out_val: %f\n",
-             static_cast<int>(stop_shared),
-             static_cast<int>(out_id[bid]),
-             static_cast<float>(out_val[bid]);
-#endif
+      out_id[bid] = sorted_id[offset];
+      out_val[bid] = sorted_probs[offset];
     }
     return;
   }
-
-#ifdef DEBUG_TOPP
-  if (tid == 0) {
-    printf(
-        "first_elem2_1: %f, first_elem2_2: %f, first_id2_1: %d, first_id2_2: "
-        "%d\n",
-        static_cast<float>(sorted_probs[offset]),
-        static_cast<float>(sorted_probs[offset + 1]),
-        static_cast<int>(sorted_id[offset]),
-        static_cast<int>(sorted_id[offset + 1]);
-  }
-#endif
   bool skip = (selected_shared[warp_id] > 0) ? false : true;
   for (int i = 0; i < warp_id; i++) {
     if (selected_shared[i] != 0) {
+      // If the previous has stopped, skip the current warp
       skip = true;
     }
   }
@@ -499,17 +505,20 @@
     int active_lane_id =
         WARP_SIZE - __popc(selected_shared[warp_id]);  // first not 0
     if (lane_id == active_lane_id) {
-#ifdef DEBUG_TOPP
-      printf(
-          "active_lane_id: %d, i_activate: %d.\n", active_lane_id, i_activate);
-      for (int i = 0; i < active_lane_id; i++) {
-        printf("p %d, value: %f\n",
-               i,
-               static_cast<float>(sorted_probs[offset + i]));
+      float val = static_cast<float>(sorted_probs[offset + i_activate]);
+      if (val < threshold_now) {
+        // don't sample low score token
+        int max_id =
+            BlockReduce(temp_storage_reduce).Reduce(threshold_id, MaxOp<int>());
+        curandStatePhilox4_32_10_t rng;
+        curand_init(seed, tid, 0, &rng);
+        int random_id = curand(&rng) % (max_id + 1);
+        out_id[bid] = sorted_id[offset + random_id];
+        out_val[bid] = sorted_probs[offset + random_id];
+      } else {
+        out_id[bid] = sorted_id[offset + i_activate];
+        out_val[bid] = sorted_probs[offset + i_activate];
       }
-#endif
-      out_id[bid] = sorted_id[offset + i_activate];
-      out_val[bid] = sorted_probs[offset + i_activate];
     }
   }
 }
@@ -544,10 +553,26 @@ __global__ void print_kernel(T* input, int size) {
   }
 }
 
+template <typename T>
+T* SafeGetTensorPtr(const DenseTensor& t) {
+  return const_cast<T*>(t.data<T>());
+}
+
+template <typename T>
+T* SafeGetTensorPtr(const DenseTensor* t) {
+  return t ? SafeGetTensorPtr<T>(*t) : nullptr;
+}
+
+template <typename T>
+T* SafeGetTensorPtr(const paddle::optional<DenseTensor>& t) {
+  return t ? SafeGetTensorPtr<T>(t.get()) : nullptr;
+}
+
 template <typename T, typename Context>
 void TopPSamplingKernel(const Context& dev_ctx,
                         const DenseTensor& x,
                         const DenseTensor& ps,
+                        const paddle::optional<DenseTensor>& threshold,
                         int random_seed,
                         DenseTensor* out,
                         DenseTensor* ids) {
@@ -597,11 +622,12 @@
       phi::Stream(reinterpret_cast<phi::StreamId>(dev_ctx.stream())));
   dev_curand_states =
       reinterpret_cast<curandState_t*>(curand_states_buf->ptr());
+  unsigned int seed = 0;
   if (random_seed == -1) {
-    srand((unsigned int)(time(NULL)));
-    setup_kernel<<<1, 256, 0, cu_stream>>>(dev_curand_states, rand(), bs);
+    rand_r(&seed);
+    setup_kernel<<<1, 256, 0, cu_stream>>>(dev_curand_states, seed, bs);
   } else {
-    setup_kernel<<<1, 256, 0, cu_stream>>>(dev_curand_states, random_seed, bs);
+    seed = random_seed;
   }
 
   DenseTensor count_iter;
@@ -612,12 +638,15 @@
   dev_ctx.template Alloc<int>(&count_iter_begin);
   SetCountIter<<<1, 256, 0, cu_stream>>>(count_iter.data<int>(), bs + 1);
 
+  T* threshold_data = SafeGetTensorPtr<T>(threshold);
+
   constexpr int TopKMaxLength = 2;
   constexpr int TopPBeamTopK = 10;
   switch (BlockSize) {
     FIXED_BLOCK_DIM(
         KeMatrixTopPBeamTopK<T, TopKMaxLength, TopPBeamTopK, kBlockDim>
         <<<bs, kBlockDim, 0, cu_stream>>>(x.data<T>(),
+                                          threshold_data,
                                           ps_now.data<T>(),
                                           ids_ptr,
                                           out_ptr,
@@ -682,6 +711,8 @@
                                           out_ptr,
                                           ids_ptr,
                                           ps_now.data<T>(),
+                                          threshold_data,
+                                          seed,
                                           p_num,
                                           vocab_size,
                                           count_iter.data<int>(),
diff --git a/paddle/phi/kernels/top_p_sampling_kernel.h b/paddle/phi/kernels/top_p_sampling_kernel.h
index e5a2bff8c315a3aec10d70d31a1bfa600c6c7aff..26a723179eeeb31a96660164e364ada8ed98bd92 100644
--- a/paddle/phi/kernels/top_p_sampling_kernel.h
+++ b/paddle/phi/kernels/top_p_sampling_kernel.h
@@ -22,6 +22,7 @@ template <typename T, typename Context>
 void TopPSamplingKernel(const Context& dev_ctx,
                         const DenseTensor& x,
                         const DenseTensor& ps,
+                        const paddle::optional<DenseTensor>& threshold,
                         int random_seed,
                         DenseTensor* out,
                         DenseTensor* ids);
diff --git a/python/paddle/fluid/tests/unittests/test_top_p_sampling.py b/python/paddle/fluid/tests/unittests/test_top_p_sampling.py
index 4a8544250ff7341b84eb7f644263d47a1f137305..2c882dd2cd65ad882e1a97a148bc40fed23d2dd7 100644
--- a/python/paddle/fluid/tests/unittests/test_top_p_sampling.py
+++ b/python/paddle/fluid/tests/unittests/test_top_p_sampling.py
@@ -53,6 +53,9 @@ def TopPProcess(probs, top_p):
     return next_scores, next_tokens
 
 
+@unittest.skipIf(
+    not core.is_compiled_with_cuda(), "core is not compiled with CUDA "
+)
 class TestTopPAPI(unittest.TestCase):
     def setUp(self):
         self.topp = 0.0
@@ -74,7 +77,7 @@
             ).reshape((-1, 1))
             # test case for basic test case 1
             paddle_result = paddle.top_p_sampling(
-                input_tensor, topp_tensor, self.seed
+                input_tensor, topp_tensor, seed=self.seed
             )
 
             ref_res = TopPProcess(input_tensor, self.topp)
@@ -98,7 +101,9 @@
             topp_tensor = paddle.static.data(
                 name="topp", shape=[6, 1], dtype=self.dtype
             )
-            result = paddle.top_p_sampling(input_tensor, topp_tensor, self.seed)
+            result = paddle.top_p_sampling(
+                input_tensor, topp_tensor, seed=self.seed
+            )
             ref_res = TopPProcess(input_tensor, self.topp)
             exe = paddle.static.Executor(place)
             input_data = np.random.rand(6, 1030).astype(self.dtype)
diff --git a/python/paddle/tensor/search.py b/python/paddle/tensor/search.py
index 20c834af585aef2932b1ae92d064be4221609151..b861df176d64c8d5ada615ef7a379dcb84edb1e7 100755
--- a/python/paddle/tensor/search.py
+++ b/python/paddle/tensor/search.py
@@ -1131,13 +1131,14 @@ def kthvalue(x, k, axis=None, keepdim=False, name=None):
     return values, indices
 
 
-def top_p_sampling(x, ps, seed=None, name=None):
+def top_p_sampling(x, ps, threshold=None, seed=None, name=None):
     """
    Get the TopP scores and ids.
 
    Args:
        x(Tensor): A N-D Tensor with type float32, float16 and bfloat16.
        ps(Tensor): A 1-D Tensor with type float32, float16 and bfloat16.
+       threshold(Tensor): A 1-D Tensor with type float32, float16 and bfloat16.
        seed(int, optional): the random seed,
        name (str, optional): For details, please refer to :ref:`api_guide_Name`.
            Generally, no setting is required. Default: None.
@@ -1149,10 +1150,10 @@
         seed = -1
 
     if in_dygraph_mode():
-        return _C_ops.top_p_sampling(x, ps, seed)
+        return _C_ops.top_p_sampling(x, ps, threshold, seed)
 
-    inputs = {"x": [x], "ps": [ps]}
-    attrs = {"seed": seed}
+    inputs = {"x": x, "ps": ps, "threshold": threshold}
+    attrs = {"random_seed": seed}
 
     helper = LayerHelper('top_p_sampling', **locals())
     out = helper.create_variable_for_type_inference(dtype=x.dtype)
@@ -1160,7 +1161,7 @@
     helper.append_op(
         type='top_p_sampling',
         inputs=inputs,
-        outputs={'out': [out], 'ids': [ids]},
+        outputs={'out': out, 'ids': ids},
        attrs=attrs,
     )
     return out, ids
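Usage note (not part of the patch): a minimal dygraph sketch of how the new optional threshold argument is expected to be called, assuming a CUDA build of Paddle since the op only has a GPU kernel. Shapes mirror the unit test above; the 0.7 top-p value, the 1e-4 cutoff, and the [batch_size, 1] shape for the threshold tensor (chosen to match ps) are illustrative assumptions, not values taken from this patch.

    import paddle
    import paddle.nn.functional as F

    paddle.seed(2023)
    # [batch_size, vocab_size] scores, normalized to probabilities
    probs = F.softmax(paddle.randn([6, 1030], dtype="float32"), axis=-1)
    top_ps = paddle.full([6, 1], 0.7, dtype="float32")    # per-batch top-p
    cutoff = paddle.full([6, 1], 1e-4, dtype="float32")   # per-batch low-score cutoff (assumed shape)

    # threshold is optional; omitting it keeps the previous sampling behaviour
    out, ids = paddle.top_p_sampling(probs, top_ps, threshold=cutoff, seed=123)

With threshold set, candidates whose probability falls below the per-batch cutoff are not returned directly: KeMatrixTopPBeamTopK falls back to an earlier, higher-scoring beam candidate, and topp_sampling instead draws a random index from the prefix of sorted tokens that still lie above the cutoff.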