Unverified commit e1545af4, authored by lzy, committed by GitHub

make top_p_sampling support threshold (#55486)

* make top_p_sampling support threshold

* delete __nv_bfloat16
Parent 0252287e
@@ -1922,13 +1922,14 @@
backward : thresholded_relu_grad
- op : top_p_sampling
args : (Tensor x, Tensor ps, int random_seed=-1)
args : (Tensor x, Tensor ps, Tensor threshold, int random_seed=-1)
output : Tensor (out), Tensor(ids)
infer_meta :
func : TopPSamplingInferMeta
kernel :
func : top_p_sampling
data_type : x
optional : threshold
- op : topk
args : (Tensor x, Scalar(int) k = 1, int axis = -1, bool largest = true, bool sorted = true)
......
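For orientation, here is a minimal dygraph usage sketch of the op signature declared above. The shapes, dtypes, and values are illustrative assumptions (a CUDA build of Paddle is required); since `threshold` is declared `optional`, omitting it keeps the old behavior.

```python
# Illustrative sketch only: shapes, dtypes, and values are assumptions.
import paddle

paddle.seed(2023)
# [batch_size, vocab_size] probabilities, one top-p value per row.
probs = paddle.nn.functional.softmax(paddle.randn([2, 1000]), axis=-1)
ps = paddle.full([2, 1], 0.9, dtype="float32")
threshold = paddle.full([2, 1], 1e-4, dtype="float32")  # per-row low-score cutoff

scores, ids = paddle.top_p_sampling(probs, ps, threshold=threshold, seed=1)
```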
@@ -2744,6 +2744,7 @@ void TriangularSolveInferMeta(const MetaTensor& x,
void TopPSamplingInferMeta(const MetaTensor& x,
const MetaTensor& ps,
const MetaTensor& threshold,
int random_seed,
MetaTensor* out,
MetaTensor* ids) {
......
@@ -430,6 +430,7 @@ void TriangularSolveInferMeta(const MetaTensor& x,
void TopPSamplingInferMeta(const MetaTensor& x,
const MetaTensor& ps,
const MetaTensor& threshold,
int random_seed,
MetaTensor* out,
MetaTensor* ids);
......
@@ -41,11 +41,6 @@ struct DataTypeTraits<phi::dtype::float16> {
using DataType = half;
};
// template <>
// struct DataTypeTraits<phi::dtype::bfloat16> {
// using DataType = __nv_bfloat16;
// };
#define FINAL_MASK 0xFFFFFFFF
#define FIXED_BLOCK_DIM_BASE(dim, ...) \
@@ -119,7 +114,7 @@ __global__ void setup_kernel(curandState_t* state,
const int bs) {
int idx = blockIdx.x * blockDim.x + threadIdx.x;
for (int i = idx; i < bs; i += gridDim.x * blockDim.x) {
curand_init(seed + i, 0, 0, &state[i]);
curand_init(seed, i, 0, &state[i]);
}
}
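The change above stops perturbing the seed per row (`seed + i`) and instead passes the row index as curand's subsequence argument, the documented way to derive independent streams from a single seed. A rough NumPy analogue of that per-row-stream idea (illustrative only, not the kernel):

```python
import numpy as np

base_seed = 12345
# One child stream per "row", all derived from a single base seed.
streams = [np.random.default_rng(s)
           for s in np.random.SeedSequence(base_seed).spawn(4)]
draws = [rng.uniform() for rng in streams]  # independent per-row draws
```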
@@ -278,6 +273,7 @@ __device__ __forceinline__ void BlockReduce(Pair<T> shared_max[],
template <typename T, int MaxLength, int TopPBeamTopK, int BlockSize>
__global__ void KeMatrixTopPBeamTopK(const T* src,
const T* threshold,
T* top_ps,
int64_t* out_id, // topk id
T* out_val, // topk val
@@ -289,6 +285,8 @@ __global__ void KeMatrixTopPBeamTopK(const T* src,
const int wid = tid / 32;
const int lane = tid % 32;
const int bid = blockIdx.x;
const float threshold_now =
threshold ? static_cast<float>(threshold[bid]) : 0.f;
int top_num = TopPBeamTopK;
float top_p_num = static_cast<float>(top_ps[bid]);
@@ -329,8 +327,10 @@ __global__ void KeMatrixTopPBeamTopK(const T* src,
float rand_top_p = curand_uniform(state + bid) * top_p_num;
top_ps[bid] = (T)rand_top_p;
float sum_prob = 0.0f;
for (int i = 0; i < TopPBeamTopK; i++) {
sum_prob += static_cast<float>(beam_max[i].v);
float val = static_cast<float>(beam_max[i].v);
sum_prob += val;
#ifdef DEBUG_TOPP
printf("bi: %d, top_p: %f, rand_top_p: %f, sum_prob: %f\n",
bid,
@@ -340,12 +340,21 @@ __global__ void KeMatrixTopPBeamTopK(const T* src,
#endif
if (sum_prob >= rand_top_p) {
count_iter_begin[bid] += 1;
out_id[bid] = (int64_t)beam_max[i].id;
out_val[bid] = beam_max[i].v;
#ifdef DEBUG_TOPP
printf(
"bi: %d, early stop id: %d\n", bid, static_cast<int>(out_id[bid]));
#endif
if (val < threshold_now) {
// don't sample low score token
int start_id = i == 0 ? 0 : i - 1;
for (int j = start_id; j >= 0; j--) {
float val_now = static_cast<float>(beam_max[j].v);
if (val_now >= threshold_now || j == 0) {
out_id[bid] = static_cast<int64_t>(beam_max[j].id);
out_val[bid] = beam_max[j].v;
break;
}
}
} else {
out_id[bid] = static_cast<int64_t>(beam_max[i].id);
out_val[bid] = beam_max[i].v;
}
break;
}
}
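In words: the early-exit branch now accepts the sampled beam entry only if its probability clears `threshold_now`; otherwise it walks back toward the head of the beam and takes the nearest candidate that does, with index 0 accepted unconditionally. A Python reference of that fallback, assuming `beam` is a list of `(prob, token_id)` pairs sorted by descending probability (names are hypothetical):

```python
def pick_with_threshold(beam, i, threshold):
    """beam: (prob, token_id) pairs, descending prob; i: index where the
    running sum first reached rand_top_p."""
    prob, token = beam[i]
    if prob >= threshold:
        return prob, token
    # Mirror the kernel's backward scan: start at i - 1 (or 0 when i == 0)
    # and fall back to beam[0] if nothing clears the threshold.
    start = 0 if i == 0 else i - 1
    for j in range(start, -1, -1):
        if beam[j][0] >= threshold or j == 0:
            return beam[j]
```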
@@ -374,11 +383,14 @@ __global__ void FillIndex(T* indices, T num_rows, T num_cols) {
}
struct BlockPrefixCallbackOp {
// Running prefix
float running_total;
// Constructor
__device__ BlockPrefixCallbackOp(float running_total)
: running_total(running_total) {}
// Callback operator to be entered by the first warp of threads in the block.
// Thread-0 is responsible for returning a value for seeding the block-wide
// scan.
__device__ float operator()(float block_aggregate) {
float old_prefix = running_total;
running_total += block_aggregate;
@@ -386,14 +398,28 @@ struct BlockPrefixCallbackOp {
}
};
template <typename T>
__device__ T max_func(const T a, const T b) {
return a > b ? a : b;
}
template <typename T>
struct MaxOp {
__device__ __forceinline__ T operator()(const T& a, const T& b) const {
return max_func(a, b);
}
};
template <typename T, int BLOCK_SIZE>
__global__ void topp_sampling(T* sorted_probs,
int64_t* sorted_id,
T* out_val,
int64_t* out_id,
const T* top_ps,
int p_num,
int vocab_size,
const T* threshold,
const uint64_t seed,
const int p_num,
const int vocab_size,
int* count_iter,
int* count_iter_begin) {
__shared__ int stop_shared;
@@ -404,6 +430,8 @@ __global__ void topp_sampling(T* sorted_probs,
const int lane_id = tid % 32;
const int warp_id = tid / 32;
const float p_t = static_cast<float>(top_ps[bid]);
const float threshold_now =
threshold ? static_cast<float>(threshold[bid]) : 0.f;
if (tid == 0) {
stop_shared = 0;
rand_p = p_t;
@@ -417,8 +445,11 @@ __global__ void topp_sampling(T* sorted_probs,
}
typedef cub::BlockScan<float, BLOCK_SIZE> BlockScan;
typedef cub::BlockReduce<int, BLOCK_SIZE> BlockReduce;
__shared__ typename BlockScan::TempStorage temp_storage;
__shared__ typename BlockReduce::TempStorage temp_storage_reduce;
__shared__ uint32_t selected_shared[NUM_WARPS];
int threshold_id = 0;
// Initialize running total
BlockPrefixCallbackOp prefix_op(0);
@@ -429,23 +460,15 @@ __global__ void topp_sampling(T* sorted_probs,
__syncthreads();
int offset = bid * vocab_size;
#ifdef DEBUG_TOPP
if (tid == 0) {
printf(
"first_elem1_1: %f, first_elem1_2: %f, first_id1_1: %d, first_id1_2: "
"%d\n",
static_cast<float>(sorted_probs[offset]),
static_cast<float>(sorted_probs[offset + 1]),
static_cast<int>(sorted_id[offset]),
static_cast<int>(sorted_id[offset + 1]));
}
#endif
int end = ((vocab_size + BLOCK_SIZE - 1) / BLOCK_SIZE) * BLOCK_SIZE;
int i_activate = 0;
float thread_offset = 0;
for (int i = tid; i < end; i += BLOCK_SIZE) {
float thread_count =
(i < vocab_size) ? static_cast<float>(sorted_probs[offset + i]) : 0.f;
if (i < vocab_size && thread_count >= threshold_now) {
threshold_id = i;
}
BlockScan(temp_storage)
.InclusiveSum(thread_count, thread_offset, prefix_op);
@@ -466,32 +489,15 @@ __global__ void topp_sampling(T* sorted_probs,
__syncthreads();
if (stop_shared == 0) {
if (tid == 0) {
out_id[bid] = sorted_id[offset + vocab_size - 1];
out_val[bid] = sorted_probs[offset + vocab_size - 1];
#ifdef DEBUG_TOPP
printf("stop_shared: %d, out_id: %d, out_val: %f\n",
static_cast<int>(stop_shared),
static_cast<int>(out_id[bid]),
static_cast<float>(out_val[bid]));
#endif
out_id[bid] = sorted_id[offset];
out_val[bid] = sorted_probs[offset];
}
return;
}
#ifdef DEBUG_TOPP
if (tid == 0) {
printf(
"first_elem2_1: %f, first_elem2_2: %f, first_id2_1: %d, first_id2_2: "
"%d\n",
static_cast<float>(sorted_probs[offset]),
static_cast<float>(sorted_probs[offset + 1]),
static_cast<int>(sorted_id[offset]),
static_cast<int>(sorted_id[offset + 1]));
}
#endif
bool skip = (selected_shared[warp_id] > 0) ? false : true;
for (int i = 0; i < warp_id; i++) {
if (selected_shared[i] != 0) {
// If a previous warp has already selected, skip the current warp
skip = true;
}
}
@@ -499,17 +505,20 @@ __global__ void topp_sampling(T* sorted_probs,
int active_lane_id =
WARP_SIZE - __popc(selected_shared[warp_id]); // first not 0
if (lane_id == active_lane_id) {
#ifdef DEBUG_TOPP
printf(
"active_lane_id: %d, i_activate: %d.\n", active_lane_id, i_activate);
for (int i = 0; i < active_lane_id; i++) {
printf("p %d, value: %f\n",
i,
static_cast<float>(sorted_probs[offset + i]));
float val = static_cast<float>(sorted_probs[offset + i_activate]);
if (val < threshold_now) {
// don't sample low score token
int max_id =
BlockReduce(temp_storage_reduce).Reduce(threshold_id, MaxOp<int>());
curandStatePhilox4_32_10_t rng;
curand_init(seed, tid, 0, &rng);
int random_id = curand(&rng) % (max_id + 1);
out_id[bid] = sorted_id[offset + random_id];
out_val[bid] = sorted_probs[offset + random_id];
} else {
out_id[bid] = sorted_id[offset + i_activate];
out_val[bid] = sorted_probs[offset + i_activate];
}
#endif
out_id[bid] = sorted_id[offset + i_activate];
out_val[bid] = sorted_probs[offset + i_activate];
}
}
}
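Taken together, the second kernel now works like this: an inclusive prefix scan over the sorted probabilities finds the first position where the running sum reaches the drawn probability `rand_p`; if that winner's probability is below `threshold_now`, the kernel instead draws uniformly from the prefix of tokens that do clear the threshold (`max_id` comes from the block-wide max reduce over `threshold_id`). A single-row NumPy reference of that selection rule (an illustration under those assumptions, not the kernel itself):

```python
import numpy as np

def topp_select(sorted_probs, sorted_ids, rand_p, threshold, rng):
    """sorted_probs descending; returns (prob, token_id) for one row."""
    cumsum = np.cumsum(sorted_probs)
    if cumsum[-1] < rand_p:                    # scan never reached rand_p:
        return sorted_probs[0], sorted_ids[0]  # fall back to the top token
    hit = int(np.argmax(cumsum >= rand_p))     # first index reaching rand_p
    if sorted_probs[hit] >= threshold:
        return sorted_probs[hit], sorted_ids[hit]
    # Below-threshold winner: resample uniformly over the qualifying prefix,
    # like `curand(&rng) % (max_id + 1)` in the kernel.
    qualifying = np.nonzero(sorted_probs >= threshold)[0]
    max_id = int(qualifying[-1]) if qualifying.size else 0
    j = int(rng.integers(0, max_id + 1))
    return sorted_probs[j], sorted_ids[j]
```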
@@ -544,10 +553,26 @@ __global__ void print_kernel(T* input, int size) {
}
}
template <typename T>
T* SafeGetTensorPtr(const DenseTensor& t) {
return const_cast<T*>(t.data<T>());
}
template <typename T>
T* SafeGetTensorPtr(const DenseTensor* t) {
return t ? SafeGetTensorPtr<T>(*t) : nullptr;
}
template <typename T>
T* SafeGetTensorPtr(const paddle::optional<DenseTensor>& t) {
return t ? SafeGetTensorPtr<T>(t.get()) : nullptr;
}
template <typename T, typename Context>
void TopPSamplingKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& ps,
const paddle::optional<DenseTensor>& threshold,
int random_seed,
DenseTensor* out,
DenseTensor* ids) {
@@ -597,11 +622,12 @@ void TopPSamplingKernel(const Context& dev_ctx,
phi::Stream(reinterpret_cast<phi::StreamId>(dev_ctx.stream())));
dev_curand_states =
reinterpret_cast<curandState_t*>(curand_states_buf->ptr());
unsigned int seed = 0;
if (random_seed == -1) {
srand((unsigned int)(time(NULL)));
setup_kernel<<<1, 256, 0, cu_stream>>>(dev_curand_states, rand(), bs);
rand_r(&seed);
setup_kernel<<<1, 256, 0, cu_stream>>>(dev_curand_states, seed, bs);
} else {
setup_kernel<<<1, 256, 0, cu_stream>>>(dev_curand_states, random_seed, bs);
seed = random_seed;
}
DenseTensor count_iter;
@@ -612,12 +638,15 @@ void TopPSamplingKernel(const Context& dev_ctx,
dev_ctx.template Alloc<int>(&count_iter_begin);
SetCountIter<<<1, 256, 0, cu_stream>>>(count_iter.data<int>(), bs + 1);
T* threshold_data = SafeGetTensorPtr<T>(threshold);
constexpr int TopKMaxLength = 2;
constexpr int TopPBeamTopK = 10;
switch (BlockSize) {
FIXED_BLOCK_DIM(
KeMatrixTopPBeamTopK<T, TopKMaxLength, TopPBeamTopK, kBlockDim>
<<<bs, kBlockDim, 0, cu_stream>>>(x.data<T>(),
threshold_data,
ps_now.data<T>(),
ids_ptr,
out_ptr,
@@ -682,6 +711,8 @@ void TopPSamplingKernel(const Context& dev_ctx,
out_ptr,
ids_ptr,
ps_now.data<T>(),
threshold_data,
seed,
p_num,
vocab_size,
count_iter.data<int>(),
......
@@ -22,6 +22,7 @@ template <typename T, typename Context>
void TopPSamplingKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& ps,
const paddle::optional<DenseTensor>& threshold,
int random_seed,
DenseTensor* out,
DenseTensor* ids);
......
@@ -53,6 +53,9 @@ def TopPProcess(probs, top_p):
return next_scores, next_tokens
@unittest.skipIf(
not core.is_compiled_with_cuda(), "core is not compiled with CUDA "
)
class TestTopPAPI(unittest.TestCase):
def setUp(self):
self.topp = 0.0
@@ -74,7 +77,7 @@ class TestTopPAPI(unittest.TestCase):
).reshape((-1, 1))
# test case for basic test case 1
paddle_result = paddle.top_p_sampling(
input_tensor, topp_tensor, self.seed
input_tensor, topp_tensor, seed=self.seed
)
ref_res = TopPProcess(input_tensor, self.topp)
@@ -98,7 +101,9 @@ class TestTopPAPI(unittest.TestCase):
topp_tensor = paddle.static.data(
name="topp", shape=[6, 1], dtype=self.dtype
)
result = paddle.top_p_sampling(input_tensor, topp_tensor, self.seed)
result = paddle.top_p_sampling(
input_tensor, topp_tensor, seed=self.seed
)
ref_res = TopPProcess(input_tensor, self.topp)
exe = paddle.static.Executor(place)
input_data = np.random.rand(6, 1030).astype(self.dtype)
......
@@ -1131,13 +1131,14 @@ def kthvalue(x, k, axis=None, keepdim=False, name=None):
return values, indices
def top_p_sampling(x, ps, seed=None, name=None):
def top_p_sampling(x, ps, threshold=None, seed=None, name=None):
"""
Get the TopP scores and ids.
Args:
x(Tensor): An N-D Tensor with type float32, float16 or bfloat16.
ps(Tensor): A 1-D Tensor with type float32, float16 or bfloat16.
threshold(Tensor, optional): A 1-D Tensor with type float32, float16 or bfloat16.
seed(int, optional): the random seed.
name (str, optional): For details, please refer to :ref:`api_guide_Name`. Generally, no setting is required. Default: None.
@@ -1149,10 +1150,10 @@ def top_p_sampling(x, ps, seed=None, name=None):
seed = -1
if in_dygraph_mode():
return _C_ops.top_p_sampling(x, ps, seed)
return _C_ops.top_p_sampling(x, ps, threshold, seed)
inputs = {"x": [x], "ps": [ps]}
attrs = {"seed": seed}
inputs = {"x": x, "ps": ps, "threshold": threshold}
attrs = {"random_seed": seed}
helper = LayerHelper('top_p_sampling', **locals())
out = helper.create_variable_for_type_inference(dtype=x.dtype)
@@ -1160,7 +1161,7 @@ def top_p_sampling(x, ps, seed=None, name=None):
helper.append_op(
type='top_p_sampling',
inputs=inputs,
outputs={'out': [out], 'ids': [ids]},
outputs={'out': out, 'ids': ids},
attrs=attrs,
)
return out, ids