未验证 提交 e1545af4 编写于 作者: L lzy 提交者: GitHub

make top_p_sampling supports threshold (#55486)

* make top_p_sampling supports threshold

* delete __nv_bfloat16
上级 0252287e
...@@ -1922,13 +1922,14 @@ ...@@ -1922,13 +1922,14 @@
backward : thresholded_relu_grad backward : thresholded_relu_grad
- op : top_p_sampling - op : top_p_sampling
args : (Tensor x, Tensor ps, int random_seed=-1) args : (Tensor x, Tensor ps, Tensor threshold, int random_seed=-1)
output : Tensor (out), Tensor(ids) output : Tensor (out), Tensor(ids)
infer_meta : infer_meta :
func : TopPSamplingInferMeta func : TopPSamplingInferMeta
kernel : kernel :
func : top_p_sampling func : top_p_sampling
data_type : x data_type : x
optional : threshold
- op : topk - op : topk
args : (Tensor x, Scalar(int) k = 1, int axis = -1, bool largest = true, bool sorted = true) args : (Tensor x, Scalar(int) k = 1, int axis = -1, bool largest = true, bool sorted = true)
......
...@@ -2744,6 +2744,7 @@ void TriangularSolveInferMeta(const MetaTensor& x, ...@@ -2744,6 +2744,7 @@ void TriangularSolveInferMeta(const MetaTensor& x,
void TopPSamplingInferMeta(const MetaTensor& x, void TopPSamplingInferMeta(const MetaTensor& x,
const MetaTensor& ps, const MetaTensor& ps,
const MetaTensor& threshold,
int random_seed, int random_seed,
MetaTensor* out, MetaTensor* out,
MetaTensor* ids) { MetaTensor* ids) {
......
...@@ -430,6 +430,7 @@ void TriangularSolveInferMeta(const MetaTensor& x, ...@@ -430,6 +430,7 @@ void TriangularSolveInferMeta(const MetaTensor& x,
void TopPSamplingInferMeta(const MetaTensor& x, void TopPSamplingInferMeta(const MetaTensor& x,
const MetaTensor& ps, const MetaTensor& ps,
const MetaTensor& threshold,
int random_seed, int random_seed,
MetaTensor* out, MetaTensor* out,
MetaTensor* ids); MetaTensor* ids);
......
...@@ -41,11 +41,6 @@ struct DataTypeTraits<phi::dtype::float16> { ...@@ -41,11 +41,6 @@ struct DataTypeTraits<phi::dtype::float16> {
using DataType = half; using DataType = half;
}; };
// template <>
// struct DataTypeTraits<phi::dtype::bfloat16> {
// using DataType = __nv_bfloat16;
// };
#define FINAL_MASK 0xFFFFFFFF #define FINAL_MASK 0xFFFFFFFF
#define FIXED_BLOCK_DIM_BASE(dim, ...) \ #define FIXED_BLOCK_DIM_BASE(dim, ...) \
...@@ -119,7 +114,7 @@ __global__ void setup_kernel(curandState_t* state, ...@@ -119,7 +114,7 @@ __global__ void setup_kernel(curandState_t* state,
const int bs) { const int bs) {
int idx = blockIdx.x * blockDim.x + threadIdx.x; int idx = blockIdx.x * blockDim.x + threadIdx.x;
for (int i = idx; i < bs; i += gridDim.x * blockDim.x) { for (int i = idx; i < bs; i += gridDim.x * blockDim.x) {
curand_init(seed + i, 0, 0, &state[i]); curand_init(seed, i, 0, &state[i]);
} }
} }
...@@ -278,6 +273,7 @@ __device__ __forceinline__ void BlockReduce(Pair<T> shared_max[], ...@@ -278,6 +273,7 @@ __device__ __forceinline__ void BlockReduce(Pair<T> shared_max[],
template <typename T, int MaxLength, int TopPBeamTopK, int BlockSize> template <typename T, int MaxLength, int TopPBeamTopK, int BlockSize>
__global__ void KeMatrixTopPBeamTopK(const T* src, __global__ void KeMatrixTopPBeamTopK(const T* src,
const T* threshold,
T* top_ps, T* top_ps,
int64_t* out_id, // topk id int64_t* out_id, // topk id
T* out_val, // topk val T* out_val, // topk val
...@@ -289,6 +285,8 @@ __global__ void KeMatrixTopPBeamTopK(const T* src, ...@@ -289,6 +285,8 @@ __global__ void KeMatrixTopPBeamTopK(const T* src,
const int wid = tid / 32; const int wid = tid / 32;
const int lane = tid % 32; const int lane = tid % 32;
const int bid = blockIdx.x; const int bid = blockIdx.x;
const float threshold_now =
threshold ? static_cast<float>(threshold[bid]) : 0.f;
int top_num = TopPBeamTopK; int top_num = TopPBeamTopK;
float top_p_num = static_cast<float>(top_ps[bid]); float top_p_num = static_cast<float>(top_ps[bid]);
...@@ -329,8 +327,10 @@ __global__ void KeMatrixTopPBeamTopK(const T* src, ...@@ -329,8 +327,10 @@ __global__ void KeMatrixTopPBeamTopK(const T* src,
float rand_top_p = curand_uniform(state + bid) * top_p_num; float rand_top_p = curand_uniform(state + bid) * top_p_num;
top_ps[bid] = (T)rand_top_p; top_ps[bid] = (T)rand_top_p;
float sum_prob = 0.0f; float sum_prob = 0.0f;
for (int i = 0; i < TopPBeamTopK; i++) { for (int i = 0; i < TopPBeamTopK; i++) {
sum_prob += static_cast<float>(beam_max[i].v); float val = static_cast<float>(beam_max[i].v);
sum_prob += val;
#ifdef DEBUG_TOPP #ifdef DEBUG_TOPP
printf("bi: %d, top_p: %f, rand_top_p: %f, sum_prob: %f\n", printf("bi: %d, top_p: %f, rand_top_p: %f, sum_prob: %f\n",
bid, bid,
...@@ -340,12 +340,21 @@ __global__ void KeMatrixTopPBeamTopK(const T* src, ...@@ -340,12 +340,21 @@ __global__ void KeMatrixTopPBeamTopK(const T* src,
#endif #endif
if (sum_prob >= rand_top_p) { if (sum_prob >= rand_top_p) {
count_iter_begin[bid] += 1; count_iter_begin[bid] += 1;
out_id[bid] = (int64_t)beam_max[i].id; if (val < threshold_now) {
// don't sample low score token
int start_id = i == 0 ? 0 : i - 1;
for (int j = start_id; j >= 0; j--) {
float val_now = static_cast<float>(beam_max[j].v);
if (val_now >= threshold_now || j == 0) {
out_id[bid] = static_cast<int64_t>(beam_max[j].id);
out_val[bid] = beam_max[j].v;
break;
}
}
} else {
out_id[bid] = static_cast<int64_t>(beam_max[i].id);
out_val[bid] = beam_max[i].v; out_val[bid] = beam_max[i].v;
#ifdef DEBUG_TOPP }
printf(
"bi: %d, early stop id: %d\n", bid, static_cast<int>(out_id[bid]));
#endif
break; break;
} }
} }
...@@ -374,11 +383,14 @@ __global__ void FillIndex(T* indices, T num_rows, T num_cols) { ...@@ -374,11 +383,14 @@ __global__ void FillIndex(T* indices, T num_rows, T num_cols) {
} }
struct BlockPrefixCallbackOp { struct BlockPrefixCallbackOp {
// Running prefix
float running_total; float running_total;
// Constructor
__device__ BlockPrefixCallbackOp(float running_total) __device__ BlockPrefixCallbackOp(float running_total)
: running_total(running_total) {} : running_total(running_total) {}
// Callback operator to be entered by the first warp of threads in the block.
// Thread-0 is responsible for returning a value for seeding the block-wide
// scan.
__device__ float operator()(float block_aggregate) { __device__ float operator()(float block_aggregate) {
float old_prefix = running_total; float old_prefix = running_total;
running_total += block_aggregate; running_total += block_aggregate;
...@@ -386,14 +398,28 @@ struct BlockPrefixCallbackOp { ...@@ -386,14 +398,28 @@ struct BlockPrefixCallbackOp {
} }
}; };
// Returns the larger of two values as decided by operator>.
// When `a > b` does not hold (including when both compare equal), `b` is
// returned.
template <typename T>
__device__ T max_func(const T a, const T b) {
  if (a > b) {
    return a;
  }
  return b;
}
// Binary functor that yields the larger of its two operands (the second
// operand wins when `lhs > rhs` is false). It is the reduction operator
// handed to cub::BlockReduce in topp_sampling.
template <typename T>
struct MaxOp {
  __device__ __forceinline__ T operator()(const T& lhs, const T& rhs) const {
    // Same comparison max_func performs, written inline.
    if (lhs > rhs) {
      return lhs;
    }
    return rhs;
  }
};
template <typename T, int BLOCK_SIZE> template <typename T, int BLOCK_SIZE>
__global__ void topp_sampling(T* sorted_probs, __global__ void topp_sampling(T* sorted_probs,
int64_t* sorted_id, int64_t* sorted_id,
T* out_val, T* out_val,
int64_t* out_id, int64_t* out_id,
const T* top_ps, const T* top_ps,
int p_num, const T* threshold,
int vocab_size, const uint64_t seed,
const int p_num,
const int vocab_size,
int* count_iter, int* count_iter,
int* count_iter_begin) { int* count_iter_begin) {
__shared__ int stop_shared; __shared__ int stop_shared;
...@@ -404,6 +430,8 @@ __global__ void topp_sampling(T* sorted_probs, ...@@ -404,6 +430,8 @@ __global__ void topp_sampling(T* sorted_probs,
const int lane_id = tid % 32; const int lane_id = tid % 32;
const int warp_id = tid / 32; const int warp_id = tid / 32;
const float p_t = static_cast<float>(top_ps[bid]); const float p_t = static_cast<float>(top_ps[bid]);
const float threshold_now =
threshold ? static_cast<float>(threshold[bid]) : 0.f;
if (tid == 0) { if (tid == 0) {
stop_shared = 0; stop_shared = 0;
rand_p = p_t; rand_p = p_t;
...@@ -417,8 +445,11 @@ __global__ void topp_sampling(T* sorted_probs, ...@@ -417,8 +445,11 @@ __global__ void topp_sampling(T* sorted_probs,
} }
typedef cub::BlockScan<float, BLOCK_SIZE> BlockScan; typedef cub::BlockScan<float, BLOCK_SIZE> BlockScan;
typedef cub::BlockReduce<int, BLOCK_SIZE> BlockReduce;
__shared__ typename BlockScan::TempStorage temp_storage; __shared__ typename BlockScan::TempStorage temp_storage;
__shared__ typename BlockReduce::TempStorage temp_storage_reduce;
__shared__ uint32_t selected_shared[NUM_WARPS]; __shared__ uint32_t selected_shared[NUM_WARPS];
int threshold_id = 0;
// Initialize running total // Initialize running total
BlockPrefixCallbackOp prefix_op(0); BlockPrefixCallbackOp prefix_op(0);
...@@ -429,23 +460,15 @@ __global__ void topp_sampling(T* sorted_probs, ...@@ -429,23 +460,15 @@ __global__ void topp_sampling(T* sorted_probs,
__syncthreads(); __syncthreads();
int offset = bid * vocab_size; int offset = bid * vocab_size;
#ifdef DEBUG_TOPP
if (tid == 0) {
printf(
"first_elem1_1: %f, first_elem1_2: %f, first_id1_1: %d, first_id1_2: "
"%d\n",
static_cast<float>(sorted_probs[offset]),
static_cast<float>(sorted_probs[offset + 1]),
static_cast<int>(sorted_id[offset]),
static_cast<int>(sorted_id[offset + 1]);
}
#endif
int end = ((vocab_size + BLOCK_SIZE - 1) / BLOCK_SIZE) * BLOCK_SIZE; int end = ((vocab_size + BLOCK_SIZE - 1) / BLOCK_SIZE) * BLOCK_SIZE;
int i_activate = 0; int i_activate = 0;
float thread_offset = 0; float thread_offset = 0;
for (int i = tid; i < end; i += BLOCK_SIZE) { for (int i = tid; i < end; i += BLOCK_SIZE) {
float thread_count = float thread_count =
(i < vocab_size) ? static_cast<float>(sorted_probs[offset + i]) : 0.f; (i < vocab_size) ? static_cast<float>(sorted_probs[offset + i]) : 0.f;
if (i < vocab_size && thread_count >= threshold_now) {
threshold_id = i;
}
BlockScan(temp_storage) BlockScan(temp_storage)
.InclusiveSum(thread_count, thread_offset, prefix_op); .InclusiveSum(thread_count, thread_offset, prefix_op);
...@@ -466,32 +489,15 @@ __global__ void topp_sampling(T* sorted_probs, ...@@ -466,32 +489,15 @@ __global__ void topp_sampling(T* sorted_probs,
__syncthreads(); __syncthreads();
if (stop_shared == 0) { if (stop_shared == 0) {
if (tid == 0) { if (tid == 0) {
out_id[bid] = sorted_id[offset + vocab_size - 1]; out_id[bid] = sorted_id[offset];
out_val[bid] = sorted_probs[offset + vocab_size - 1]; out_val[bid] = sorted_probs[offset];
#ifdef DEBUG_TOPP
printf("stop_shared: %d, out_id: %d, out_val: %f\n",
static_cast<int>(stop_shared),
static_cast<int>(out_id[bid]),
static_cast<float>(out_val[bid]);
#endif
} }
return; return;
} }
#ifdef DEBUG_TOPP
if (tid == 0) {
printf(
"first_elem2_1: %f, first_elem2_2: %f, first_id2_1: %d, first_id2_2: "
"%d\n",
static_cast<float>(sorted_probs[offset]),
static_cast<float>(sorted_probs[offset + 1]),
static_cast<int>(sorted_id[offset]),
static_cast<int>(sorted_id[offset + 1]);
}
#endif
bool skip = (selected_shared[warp_id] > 0) ? false : true; bool skip = (selected_shared[warp_id] > 0) ? false : true;
for (int i = 0; i < warp_id; i++) { for (int i = 0; i < warp_id; i++) {
if (selected_shared[i] != 0) { if (selected_shared[i] != 0) {
// If the previous has stopped, skip the current warp
skip = true; skip = true;
} }
} }
...@@ -499,19 +505,22 @@ __global__ void topp_sampling(T* sorted_probs, ...@@ -499,19 +505,22 @@ __global__ void topp_sampling(T* sorted_probs,
int active_lane_id = int active_lane_id =
WARP_SIZE - __popc(selected_shared[warp_id]); // first not 0 WARP_SIZE - __popc(selected_shared[warp_id]); // first not 0
if (lane_id == active_lane_id) { if (lane_id == active_lane_id) {
#ifdef DEBUG_TOPP float val = static_cast<float>(sorted_probs[offset + i_activate]);
printf( if (val < threshold_now) {
"active_lane_id: %d, i_activate: %d.\n", active_lane_id, i_activate); // don't sample low score token
for (int i = 0; i < active_lane_id; i++) { int max_id =
printf("p %d, value: %f\n", BlockReduce(temp_storage_reduce).Reduce(threshold_id, MaxOp<int>());
i, curandStatePhilox4_32_10_t rng;
static_cast<float>(sorted_probs[offset + i])); curand_init(seed, tid, 0, &rng);
} int random_id = curand(&rng) % (max_id + 1);
#endif out_id[bid] = sorted_id[offset + random_id];
out_val[bid] = sorted_probs[offset + random_id];
} else {
out_id[bid] = sorted_id[offset + i_activate]; out_id[bid] = sorted_id[offset + i_activate];
out_val[bid] = sorted_probs[offset + i_activate]; out_val[bid] = sorted_probs[offset + i_activate];
} }
} }
}
} }
int GetBlockSize(int vocab_size) { int GetBlockSize(int vocab_size) {
...@@ -544,10 +553,26 @@ __global__ void print_kernel(T* input, int size) { ...@@ -544,10 +553,26 @@ __global__ void print_kernel(T* input, int size) {
} }
} }
// Reference overload: returns the tensor's data pointer for element type T
// with the const qualifier removed via const_cast.
template <typename T>
T* SafeGetTensorPtr(const DenseTensor& t) {
  const T* read_only_ptr = t.data<T>();
  return const_cast<T*>(read_only_ptr);
}
// Pointer overload: a null tensor pointer maps to a null data pointer;
// otherwise the call is forwarded to the reference overload.
template <typename T>
T* SafeGetTensorPtr(const DenseTensor* t) {
  if (t) {
    return SafeGetTensorPtr<T>(*t);
  }
  return nullptr;
}
// Optional overload: an empty optional maps to a null data pointer;
// otherwise the contained tensor is forwarded to the reference overload.
template <typename T>
T* SafeGetTensorPtr(const paddle::optional<DenseTensor>& t) {
  if (t) {
    return SafeGetTensorPtr<T>(t.get());
  }
  return nullptr;
}
template <typename T, typename Context> template <typename T, typename Context>
void TopPSamplingKernel(const Context& dev_ctx, void TopPSamplingKernel(const Context& dev_ctx,
const DenseTensor& x, const DenseTensor& x,
const DenseTensor& ps, const DenseTensor& ps,
const paddle::optional<DenseTensor>& threshold,
int random_seed, int random_seed,
DenseTensor* out, DenseTensor* out,
DenseTensor* ids) { DenseTensor* ids) {
...@@ -597,11 +622,12 @@ void TopPSamplingKernel(const Context& dev_ctx, ...@@ -597,11 +622,12 @@ void TopPSamplingKernel(const Context& dev_ctx,
phi::Stream(reinterpret_cast<phi::StreamId>(dev_ctx.stream()))); phi::Stream(reinterpret_cast<phi::StreamId>(dev_ctx.stream())));
dev_curand_states = dev_curand_states =
reinterpret_cast<curandState_t*>(curand_states_buf->ptr()); reinterpret_cast<curandState_t*>(curand_states_buf->ptr());
unsigned int seed = 0;
if (random_seed == -1) { if (random_seed == -1) {
srand((unsigned int)(time(NULL))); rand_r(&seed);
setup_kernel<<<1, 256, 0, cu_stream>>>(dev_curand_states, rand(), bs); setup_kernel<<<1, 256, 0, cu_stream>>>(dev_curand_states, seed, bs);
} else { } else {
setup_kernel<<<1, 256, 0, cu_stream>>>(dev_curand_states, random_seed, bs); seed = random_seed;
} }
DenseTensor count_iter; DenseTensor count_iter;
...@@ -612,12 +638,15 @@ void TopPSamplingKernel(const Context& dev_ctx, ...@@ -612,12 +638,15 @@ void TopPSamplingKernel(const Context& dev_ctx,
dev_ctx.template Alloc<int>(&count_iter_begin); dev_ctx.template Alloc<int>(&count_iter_begin);
SetCountIter<<<1, 256, 0, cu_stream>>>(count_iter.data<int>(), bs + 1); SetCountIter<<<1, 256, 0, cu_stream>>>(count_iter.data<int>(), bs + 1);
T* threshold_data = SafeGetTensorPtr<T>(threshold);
constexpr int TopKMaxLength = 2; constexpr int TopKMaxLength = 2;
constexpr int TopPBeamTopK = 10; constexpr int TopPBeamTopK = 10;
switch (BlockSize) { switch (BlockSize) {
FIXED_BLOCK_DIM( FIXED_BLOCK_DIM(
KeMatrixTopPBeamTopK<T, TopKMaxLength, TopPBeamTopK, kBlockDim> KeMatrixTopPBeamTopK<T, TopKMaxLength, TopPBeamTopK, kBlockDim>
<<<bs, kBlockDim, 0, cu_stream>>>(x.data<T>(), <<<bs, kBlockDim, 0, cu_stream>>>(x.data<T>(),
threshold_data,
ps_now.data<T>(), ps_now.data<T>(),
ids_ptr, ids_ptr,
out_ptr, out_ptr,
...@@ -682,6 +711,8 @@ void TopPSamplingKernel(const Context& dev_ctx, ...@@ -682,6 +711,8 @@ void TopPSamplingKernel(const Context& dev_ctx,
out_ptr, out_ptr,
ids_ptr, ids_ptr,
ps_now.data<T>(), ps_now.data<T>(),
threshold_data,
seed,
p_num, p_num,
vocab_size, vocab_size,
count_iter.data<int>(), count_iter.data<int>(),
......
...@@ -22,6 +22,7 @@ template <typename T, typename Context> ...@@ -22,6 +22,7 @@ template <typename T, typename Context>
void TopPSamplingKernel(const Context& dev_ctx, void TopPSamplingKernel(const Context& dev_ctx,
const DenseTensor& x, const DenseTensor& x,
const DenseTensor& ps, const DenseTensor& ps,
const paddle::optional<DenseTensor>& threshold,
int random_seed, int random_seed,
DenseTensor* out, DenseTensor* out,
DenseTensor* ids); DenseTensor* ids);
......
...@@ -53,6 +53,9 @@ def TopPProcess(probs, top_p): ...@@ -53,6 +53,9 @@ def TopPProcess(probs, top_p):
return next_scores, next_tokens return next_scores, next_tokens
@unittest.skipIf(
not core.is_compiled_with_cuda(), "core is not compiled with CUDA "
)
class TestTopPAPI(unittest.TestCase): class TestTopPAPI(unittest.TestCase):
def setUp(self): def setUp(self):
self.topp = 0.0 self.topp = 0.0
...@@ -74,7 +77,7 @@ class TestTopPAPI(unittest.TestCase): ...@@ -74,7 +77,7 @@ class TestTopPAPI(unittest.TestCase):
).reshape((-1, 1)) ).reshape((-1, 1))
# test case for basic test case 1 # test case for basic test case 1
paddle_result = paddle.top_p_sampling( paddle_result = paddle.top_p_sampling(
input_tensor, topp_tensor, self.seed input_tensor, topp_tensor, seed=self.seed
) )
ref_res = TopPProcess(input_tensor, self.topp) ref_res = TopPProcess(input_tensor, self.topp)
...@@ -98,7 +101,9 @@ class TestTopPAPI(unittest.TestCase): ...@@ -98,7 +101,9 @@ class TestTopPAPI(unittest.TestCase):
topp_tensor = paddle.static.data( topp_tensor = paddle.static.data(
name="topp", shape=[6, 1], dtype=self.dtype name="topp", shape=[6, 1], dtype=self.dtype
) )
result = paddle.top_p_sampling(input_tensor, topp_tensor, self.seed) result = paddle.top_p_sampling(
input_tensor, topp_tensor, seed=self.seed
)
ref_res = TopPProcess(input_tensor, self.topp) ref_res = TopPProcess(input_tensor, self.topp)
exe = paddle.static.Executor(place) exe = paddle.static.Executor(place)
input_data = np.random.rand(6, 1030).astype(self.dtype) input_data = np.random.rand(6, 1030).astype(self.dtype)
......
...@@ -1131,13 +1131,14 @@ def kthvalue(x, k, axis=None, keepdim=False, name=None): ...@@ -1131,13 +1131,14 @@ def kthvalue(x, k, axis=None, keepdim=False, name=None):
return values, indices return values, indices
def top_p_sampling(x, ps, seed=None, name=None): def top_p_sampling(x, ps, threshold=None, seed=None, name=None):
""" """
Get the TopP scores and ids. Get the TopP scores and ids.
Args: Args:
x(Tensor): A N-D Tensor with type float32, float16 and bfloat16. x(Tensor): A N-D Tensor with type float32, float16 and bfloat16.
ps(Tensor): A 1-D Tensor with type float32, float16 and bfloat16. ps(Tensor): A 1-D Tensor with type float32, float16 and bfloat16.
threshold(Tensor, optional): A 1-D Tensor with type float32, float16 and bfloat16, giving the score below which a token is not sampled. Default: None.
seed(int, optional): the random seed, seed(int, optional): the random seed,
name (str, optional): For details, please refer to :ref:`api_guide_Name`. Generally, no setting is required. Default: None. name (str, optional): For details, please refer to :ref:`api_guide_Name`. Generally, no setting is required. Default: None.
...@@ -1149,10 +1150,10 @@ def top_p_sampling(x, ps, seed=None, name=None): ...@@ -1149,10 +1150,10 @@ def top_p_sampling(x, ps, seed=None, name=None):
seed = -1 seed = -1
if in_dygraph_mode(): if in_dygraph_mode():
return _C_ops.top_p_sampling(x, ps, seed) return _C_ops.top_p_sampling(x, ps, threshold, seed)
inputs = {"x": [x], "ps": [ps]} inputs = {"x": x, "ps": ps, "threshold": threshold}
attrs = {"seed": seed} attrs = {"random_seed": seed}
helper = LayerHelper('top_p_sampling', **locals()) helper = LayerHelper('top_p_sampling', **locals())
out = helper.create_variable_for_type_inference(dtype=x.dtype) out = helper.create_variable_for_type_inference(dtype=x.dtype)
...@@ -1160,7 +1161,7 @@ def top_p_sampling(x, ps, seed=None, name=None): ...@@ -1160,7 +1161,7 @@ def top_p_sampling(x, ps, seed=None, name=None):
helper.append_op( helper.append_op(
type='top_p_sampling', type='top_p_sampling',
inputs=inputs, inputs=inputs,
outputs={'out': [out], 'ids': [ids]}, outputs={'out': out, 'ids': ids},
attrs=attrs, attrs=attrs,
) )
return out, ids return out, ids
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册