Unverified commit f9c9dc29, authored by lzy, committed by GitHub

add top_p_sampling (#54127)

Parent 435560f0
@@ -1921,6 +1921,15 @@
  func : thresholded_relu
  backward : thresholded_relu_grad
- op : top_p_sampling
args : (Tensor x, Tensor ps, int random_seed=-1)
output : Tensor (out), Tensor(ids)
infer_meta :
func : TopPSamplingInferMeta
kernel :
func : top_p_sampling
data_type : x
- op : topk
  args : (Tensor x, Scalar(int) k = 1, int axis = -1, bool largest = true, bool sorted = true)
  output : Tensor(out), Tensor(indices)
...
@@ -2742,6 +2742,26 @@ void TriangularSolveInferMeta(const MetaTensor& x,
  out->share_lod(y);
}
void TopPSamplingInferMeta(const MetaTensor& x,
const MetaTensor& ps,
int random_seed,
MetaTensor* out,
MetaTensor* ids) {
auto x_dims = x.dims();
auto ps_dims = ps.dims();
PADDLE_ENFORCE_EQ(x_dims[0],
ps_dims[0],
phi::errors::InvalidArgument(
"The x_dims[0] must be equal to ps_dims[0] "
"But received x_dims[0] = %d and ps_dims[0] = %d.",
x_dims[0],
ps_dims[0]));
ids->set_dims(phi::make_ddim({x_dims[0], 1}));
ids->set_dtype(DataType::INT64);
out->set_dims(phi::make_ddim({x_dims[0], 1}));
out->set_dtype(x.dtype());
}
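For reference, the shape contract enforced above can be sketched in Python (a hypothetical helper, not part of this patch): both outputs are [batch_size, 1], with out carrying scores in x's dtype and ids carrying int64 token indices.

def top_p_sampling_infer_shapes(x_shape, ps_shape):
    # Mirrors TopPSamplingInferMeta: leading (batch) dims must match;
    # out is [bs, 1] in x's dtype, ids is [bs, 1] int64.
    assert x_shape[0] == ps_shape[0], (
        "x_dims[0] must equal ps_dims[0], got %d and %d"
        % (x_shape[0], ps_shape[0])
    )
    return (x_shape[0], 1), (x_shape[0], 1)

print(top_p_sampling_infer_shapes((3, 10000), (3, 1)))  # ((3, 1), (3, 1))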
void LstsqInferMeta(const MetaTensor& x,
                    const MetaTensor& y,
                    const Scalar& rcond,
...
@@ -428,6 +428,12 @@ void TriangularSolveInferMeta(const MetaTensor& x,
                              bool unitriangular,
                              MetaTensor* out);
void TopPSamplingInferMeta(const MetaTensor& x,
const MetaTensor& ps,
int random_seed,
MetaTensor* out,
MetaTensor* ids);
void LstsqInferMeta(const MetaTensor& x,
                    const MetaTensor& y,
                    const Scalar& rcond,
...
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/top_p_sampling_kernel.h"
#include <cuda_fp16.h>
#include <curand_kernel.h>
#include "cub/cub.cuh"
#include "paddle/fluid/operators/kernel_primitives/functor_primitives.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/backends/gpu/gpu_device_function.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/tensor_utils.h"
#include "paddle/phi/kernels/funcs/gather.cu.h"
#include "paddle/phi/kernels/funcs/math_function.h"
#include "paddle/phi/kernels/funcs/top_k_function_cuda.h"
// #define DEBUG_TOPP
namespace phi {
template <typename T>
struct DataTypeTraits {
using DataType = T;
};
template <>
struct DataTypeTraits<phi::dtype::float16> {
using DataType = half;
};
// template <>
// struct DataTypeTraits<phi::dtype::bfloat16> {
// using DataType = __nv_bfloat16;
// };
#define FINAL_MASK 0xFFFFFFFF
#define FIXED_BLOCK_DIM_BASE(dim, ...) \
case (dim): { \
constexpr auto kBlockDim = (dim); \
__VA_ARGS__; \
} break
#define FIXED_BLOCK_DIM(...) \
FIXED_BLOCK_DIM_BASE(1024, ##__VA_ARGS__); \
FIXED_BLOCK_DIM_BASE(512, ##__VA_ARGS__); \
FIXED_BLOCK_DIM_BASE(256, ##__VA_ARGS__); \
FIXED_BLOCK_DIM_BASE(128, ##__VA_ARGS__); \
FIXED_BLOCK_DIM_BASE(64, ##__VA_ARGS__); \
FIXED_BLOCK_DIM_BASE(32, ##__VA_ARGS__)
namespace ops = paddle::operators;
struct SegmentOffsetIter {
explicit SegmentOffsetIter(int num_cols) : num_cols_(num_cols) {}
__host__ __device__ __forceinline__ int operator()(int idx) const {
return idx * num_cols_;
}
int num_cols_;
};
template <typename T>
struct Pair {
__device__ __forceinline__ Pair() {}
__device__ __forceinline__ Pair(T value, int id) : v(value), id(id) {}
  __device__ __forceinline__ void set(T value, int id) {
    v = value;
    this->id = id;  // the parameter shadows the member; plain `id = id` is a self-assignment
  }
__device__ __forceinline__ void operator=(const Pair<T>& in) {
v = in.v;
id = in.id;
}
__device__ __forceinline__ bool operator<(const T value) const {
return (static_cast<float>(v) < static_cast<float>(value));
}
__device__ __forceinline__ bool operator>(const T value) const {
return (static_cast<float>(v) > static_cast<float>(value));
}
__device__ __forceinline__ bool operator<(const Pair<T>& in) const {
return (static_cast<float>(v) < static_cast<float>(in.v)) ||
((static_cast<float>(v) == static_cast<float>(in.v)) &&
(id > in.id));
}
__device__ __forceinline__ bool operator>(const Pair<T>& in) const {
return (static_cast<float>(v) > static_cast<float>(in.v)) ||
((static_cast<float>(v) == static_cast<float>(in.v)) &&
(id < in.id));
}
T v;
int id;
};
inline int div_up(int a, int n) { return (a + n - 1) / n; }
__global__ void setup_kernel(curandState_t* state,
const uint64_t seed,
const int bs) {
int idx = blockIdx.x * blockDim.x + threadIdx.x;
for (int i = idx; i < bs; i += gridDim.x * blockDim.x) {
curand_init(seed + i, 0, 0, &state[i]);
}
}
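setup_kernel gives every batch row its own generator state seeded with seed + i, so rows draw independent uniforms; the same idea in plain Python (NumPy standing in for curand):

import numpy as np

def make_row_generators(seed, batch_size):
    # One independent RNG per row, like curand_init(seed + i, 0, 0, &state[i]).
    return [np.random.default_rng(seed + i) for i in range(batch_size)]

gens = make_row_generators(6688, 3)
print([g.uniform() for g in gens])  # three independent per-row draws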
template <typename T>
__device__ __forceinline__ void AddTo(Pair<T> topk[],
const Pair<T>& p,
int beam_size) {
for (int k = beam_size - 2; k >= 0; k--) {
if (topk[k] < p) {
topk[k + 1] = topk[k];
} else {
topk[k + 1] = p;
return;
}
}
topk[0] = p;
}
template <typename T, int BlockSize>
__device__ __forceinline__ void GetTopK(
Pair<T> topk[], const T* src, int idx, int dim, int beam_size) {
while (idx < dim) {
if (topk[beam_size - 1] < src[idx]) {
Pair<T> tmp(src[idx], idx);
AddTo<T>(topk, tmp, beam_size);
}
idx += BlockSize;
}
}
template <typename T, int BlockSize>
__device__ __forceinline__ void GetTopK(Pair<T> topk[],
const T* src,
int idx,
int dim,
const Pair<T>& max,
int beam_size) {
while (idx < dim) {
if (topk[beam_size - 1] < src[idx]) {
Pair<T> tmp(src[idx], idx);
if (tmp < max) {
AddTo<T>(topk, tmp, beam_size);
}
}
idx += BlockSize;
}
}
template <typename T, int MaxLength, int BlockSize>
__device__ __forceinline__ void ThreadGetTopK(Pair<T> topk[],
int* beam,
int beam_size,
const T* src,
bool* firstStep,
bool* is_empty,
Pair<T>* max,
int dim,
const int tid) {
if (*beam > 0) {
int length = (*beam) < beam_size ? *beam : beam_size;
if (*firstStep) {
*firstStep = false;
GetTopK<T, BlockSize>(topk, src, tid, dim, length);
} else {
for (int k = 0; k < MaxLength; k++) {
if (k < MaxLength - (*beam)) {
topk[k] = topk[k + *beam];
} else {
topk[k].set(std::numeric_limits<T>::min(), -1);
}
}
if (!(*is_empty)) {
GetTopK<T, BlockSize>(
topk + MaxLength - *beam, src, tid, dim, *max, length);
}
}
*max = topk[MaxLength - 1];
if ((*max).id == -1) *is_empty = true;
*beam = 0;
}
}
template <typename T>
__forceinline__ __device__ Pair<T> WarpReduce(Pair<T> input) {
#pragma unroll
for (int offset = 16; offset > 0; offset >>= 1) {
T tmp_val =
phi::backends::gpu::CudaShuffleDownSync(FINAL_MASK, input.v, offset);
int tmp_id =
phi::backends::gpu::CudaShuffleDownSync(FINAL_MASK, input.id, offset);
if (static_cast<float>(input.v) < static_cast<float>(tmp_val)) {
input.v = tmp_val;
input.id = tmp_id;
}
}
return input;
}
template <typename T, int MaxLength, int BlockSize>
__device__ __forceinline__ void BlockReduce(Pair<T> shared_max[],
Pair<T> topk[],
Pair<T> beam_max[],
int* beam,
int* k,
int* count,
const int tid,
const int wid,
const int lane) {
while (true) {
__syncthreads();
Pair<T> input_now = topk[0];
input_now = WarpReduce(input_now);
if (lane == 0) {
shared_max[wid] = input_now;
}
__syncthreads();
input_now = (tid < BlockSize / 32)
? shared_max[lane]
: Pair<T>(std::numeric_limits<T>::min(), -1);
if (wid == 0) {
input_now = WarpReduce(input_now);
if (lane == 0) shared_max[0] = input_now;
}
__syncthreads();
if (tid == 0) {
beam_max[*count] = shared_max[0];
(*count)++;
}
int tid_max = shared_max[0].id % BlockSize;
if (tid == tid_max) {
(*beam)++;
}
if (--(*k) == 0) break;
__syncthreads();
if (tid == tid_max) {
if (*beam < MaxLength) {
topk[0] = topk[*beam];
}
}
if (MaxLength < 5) {
if (*beam >= MaxLength) break;
} else {
      __ballot_sync(FINAL_MASK, true);  // converge the warp; the ballot result is unused
if (tid_max / 32 == wid) {
if (__shfl_down_sync(FINAL_MASK, *beam, tid_max % 32, 32) == MaxLength)
break;
}
}
}
}
template <typename T, int MaxLength, int TopPBeamTopK, int BlockSize>
__global__ void KeMatrixTopPBeamTopK(const T* src,
T* top_ps,
int64_t* out_id, // topk id
T* out_val, // topk val
int vocab_size,
curandState_t* state,
int* count_iter,
int* count_iter_begin) {
const int tid = threadIdx.x;
const int wid = tid / 32;
const int lane = tid % 32;
const int bid = blockIdx.x;
int top_num = TopPBeamTopK;
float top_p_num = static_cast<float>(top_ps[bid]);
__shared__ Pair<T> shared_max[BlockSize / 32];
__shared__ Pair<T> beam_max[TopPBeamTopK];
Pair<T> topk[MaxLength];
int beam = MaxLength;
Pair<T> max;
bool is_empty = false;
bool firststep = true;
__shared__ int count;
if (tid == 0) {
count = 0;
}
for (int j = 0; j < MaxLength; j++) {
topk[j].set(std::numeric_limits<T>::min(), -1);
}
while (top_num) {
ThreadGetTopK<T, MaxLength, BlockSize>(topk,
&beam,
TopPBeamTopK,
src + bid * vocab_size,
&firststep,
&is_empty,
&max,
vocab_size,
tid);
BlockReduce<T, MaxLength, BlockSize>(
shared_max, topk, beam_max, &beam, &top_num, &count, tid, wid, lane);
}
if (tid == 0) {
count_iter_begin[bid] = count_iter[bid];
float rand_top_p = curand_uniform(state + bid) * top_p_num;
top_ps[bid] = (T)rand_top_p;
float sum_prob = 0.0f;
for (int i = 0; i < TopPBeamTopK; i++) {
sum_prob += static_cast<float>(beam_max[i].v);
#ifdef DEBUG_TOPP
printf("bi: %d, top_p: %f, rand_top_p: %f, sum_prob: %f\n",
bid,
top_p_num,
rand_top_p,
sum_prob);
#endif
if (sum_prob >= rand_top_p) {
count_iter_begin[bid] += 1;
out_id[bid] = (int64_t)beam_max[i].id;
out_val[bid] = beam_max[i].v;
#ifdef DEBUG_TOPP
printf(
"bi: %d, early stop id: %d\n", bid, static_cast<int>(out_id[bid]));
#endif
break;
}
}
}
}
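The kernel above is the fast path: thread 0 rescales the threshold to rand_top_p = uniform() * top_p, then walks the row's top TopPBeamTopK candidates in descending order; if their running sum already reaches the threshold, the sample is emitted and count_iter_begin is bumped so the later sort skips this row. A NumPy sketch of that per-row decision (hypothetical helper):

import numpy as np

def beam_topk_fast_path(probs, p, k=10, rng=None):
    # Try to sample within the top-k candidates only. Returns (token_id, prob)
    # when the top-k prefix sum reaches rand * p (the common case for peaked
    # distributions), else None to signal that the full sorted pass is needed.
    rng = rng or np.random.default_rng()
    rand_top_p = rng.uniform() * p
    top_ids = np.argsort(-probs)[:k]  # descending top-k
    sum_prob = 0.0
    for tid in top_ids:
        sum_prob += probs[tid]
        if sum_prob >= rand_top_p:
            return int(tid), float(probs[tid])
    return None  # fall through to the segmented sort + scan path

probs = np.array([0.6, 0.2, 0.1, 0.05, 0.05])
print(beam_topk_fast_path(probs, p=0.9, k=2, rng=np.random.default_rng(0)))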
__global__ void SetCountIter(int* count_iter, int num) {
int tid = threadIdx.x;
int bid = blockIdx.x;
int idx = bid * blockDim.x + tid;
for (int i = idx; i < num; i += gridDim.x * blockDim.x) {
count_iter[i] = i;
}
}
template <typename T>
__global__ void FillIndex(T* indices, T num_rows, T num_cols) {
int col_id = threadIdx.x;
int row_id = blockIdx.x;
for (T j = row_id; j < num_rows; j += gridDim.x) {
for (T i = col_id; i < num_cols; i += blockDim.x) {
indices[j * num_cols + i] = i;
}
}
}
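FillIndex simply writes 0..num_cols-1 into every row, producing the token-index matrix that is later sorted alongside the probabilities; the NumPy equivalent is one line:

import numpy as np

bs, vocab_size = 2, 5
inds_input = np.tile(np.arange(vocab_size, dtype=np.int64), (bs, 1))
print(inds_input)  # each row is [0, 1, 2, 3, 4], as FillIndex produces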
struct BlockPrefixCallbackOp {
float running_total;
__device__ BlockPrefixCallbackOp(float running_total)
: running_total(running_total) {}
__device__ float operator()(float block_aggregate) {
float old_prefix = running_total;
running_total += block_aggregate;
return old_prefix;
}
};
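BlockPrefixCallbackOp is the stock CUB idiom for scanning a row longer than one block tile: after each tile's InclusiveSum, CUB hands the functor the tile aggregate and the functor returns the running prefix for the next tile. The same carry logic in Python:

def tiled_inclusive_scan(values, tile_size):
    # Inclusive prefix sum computed tile by tile with a running carry,
    # mirroring cub::BlockScan plus BlockPrefixCallbackOp.
    out, running_total = [], 0.0
    for start in range(0, len(values), tile_size):
        tile = values[start:start + tile_size]
        acc = 0.0
        for v in tile:  # BlockScan::InclusiveSum within the tile
            acc += v
            out.append(running_total + acc)
        running_total += acc  # the callback consumes the tile aggregate
    return out

print(tiled_inclusive_scan([1, 2, 3, 4, 5], tile_size=2))  # [1, 3, 6, 10, 15]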
template <typename T, int BLOCK_SIZE>
__global__ void topp_sampling(T* sorted_probs,
int64_t* sorted_id,
T* out_val,
int64_t* out_id,
const T* top_ps,
int p_num,
int vocab_size,
int* count_iter,
int* count_iter_begin) {
__shared__ int stop_shared;
__shared__ float rand_p;
const int tid = threadIdx.x;
const int bid = blockIdx.x;
constexpr int NUM_WARPS = BLOCK_SIZE / 32;
const int lane_id = tid % 32;
const int warp_id = tid / 32;
const float p_t = static_cast<float>(top_ps[bid]);
if (tid == 0) {
stop_shared = 0;
rand_p = p_t;
#ifdef DEBUG_TOPP
printf("bi: %d, p: %f\n", bid, rand_p);
#endif
}
if (count_iter_begin[bid] == count_iter[bid + 1]) {
    // this row was already decided by the top-k fast path
return;
}
typedef cub::BlockScan<float, BLOCK_SIZE> BlockScan;
__shared__ typename BlockScan::TempStorage temp_storage;
__shared__ uint32_t selected_shared[NUM_WARPS];
// Initialize running total
BlockPrefixCallbackOp prefix_op(0);
if (lane_id == 0) {
selected_shared[warp_id] = 0;
}
__syncthreads();
int offset = bid * vocab_size;
#ifdef DEBUG_TOPP
if (tid == 0) {
printf(
"first_elem1_1: %f, first_elem1_2: %f, first_id1_1: %d, first_id1_2: "
"%d\n",
static_cast<float>(sorted_probs[offset]),
static_cast<float>(sorted_probs[offset + 1]),
static_cast<int>(sorted_id[offset]),
        static_cast<int>(sorted_id[offset + 1]));
}
#endif
int end = ((vocab_size + BLOCK_SIZE - 1) / BLOCK_SIZE) * BLOCK_SIZE;
int i_activate = 0;
float thread_offset = 0;
for (int i = tid; i < end; i += BLOCK_SIZE) {
float thread_count =
(i < vocab_size) ? static_cast<float>(sorted_probs[offset + i]) : 0.f;
BlockScan(temp_storage)
.InclusiveSum(thread_count, thread_offset, prefix_op);
uint32_t activate_mask = __ballot_sync(FINAL_MASK, rand_p <= thread_offset);
i_activate = i;
if (activate_mask != 0) {
if (lane_id == 0) {
atomicAdd(&stop_shared, 1);
selected_shared[warp_id] = activate_mask;
}
}
__syncthreads();
if (stop_shared > 0) {
break;
}
}
__syncthreads();
if (stop_shared == 0) {
if (tid == 0) {
out_id[bid] = sorted_id[offset + vocab_size - 1];
out_val[bid] = sorted_probs[offset + vocab_size - 1];
#ifdef DEBUG_TOPP
printf("stop_shared: %d, out_id: %d, out_val: %f\n",
static_cast<int>(stop_shared),
static_cast<int>(out_id[bid]),
             static_cast<float>(out_val[bid]));
#endif
}
return;
}
#ifdef DEBUG_TOPP
if (tid == 0) {
printf(
"first_elem2_1: %f, first_elem2_2: %f, first_id2_1: %d, first_id2_2: "
"%d\n",
static_cast<float>(sorted_probs[offset]),
static_cast<float>(sorted_probs[offset + 1]),
static_cast<int>(sorted_id[offset]),
        static_cast<int>(sorted_id[offset + 1]));
}
#endif
  bool skip = (selected_shared[warp_id] == 0);
for (int i = 0; i < warp_id; i++) {
if (selected_shared[i] != 0) {
skip = true;
}
}
if (!skip) {
    // The ballot mask is a contiguous run of high lanes (prefix sums increase
    // with thread id), so the first set lane is WARP_SIZE - popcount(mask).
    int active_lane_id = WARP_SIZE - __popc(selected_shared[warp_id]);
if (lane_id == active_lane_id) {
#ifdef DEBUG_TOPP
printf(
"active_lane_id: %d, i_activate: %d.\n", active_lane_id, i_activate);
for (int i = 0; i < active_lane_id; i++) {
printf("p %d, value: %f\n",
i,
static_cast<float>(sorted_probs[offset + i]));
}
#endif
out_id[bid] = sorted_id[offset + i_activate];
out_val[bid] = sorted_probs[offset + i_activate];
}
}
}
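Stripped of the warp bookkeeping, the kernel selects the first position (in descending-probability order) whose inclusive prefix sum reaches the rescaled threshold, falling back to the last sorted token when nothing qualifies (stop_shared == 0). A NumPy reference for one row (hypothetical helper):

import numpy as np

def top_p_select(sorted_probs, sorted_ids, rand_p):
    # Pick the first token whose inclusive cumulative probability reaches
    # rand_p, scanning in descending-probability order; fall back to the
    # last token, as the kernel does when no position qualifies.
    cum = np.cumsum(sorted_probs)
    hits = np.nonzero(cum >= rand_p)[0]
    i = hits[0] if hits.size else len(sorted_probs) - 1
    return int(sorted_ids[i]), float(sorted_probs[i])

probs = np.array([0.5, 0.3, 0.15, 0.05])
ids = np.array([7, 3, 1, 9])
print(top_p_select(probs, ids, rand_p=0.6))  # (3, 0.3): cumsum reaches 0.8 >= 0.6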
int GetBlockSize(int vocab_size) {
if (vocab_size > 512) {
return 1024;
} else if (vocab_size > 256) {
return 512;
} else if (vocab_size > 128) {
return 256;
} else if (vocab_size > 64) {
return 128;
} else {
return 64;
}
}
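GetBlockSize picks the smallest power-of-two block between 64 and 1024 that covers the vocabulary, which is what the FIXED_BLOCK_DIM dispatch expects; equivalently in Python:

def get_block_size(vocab_size):
    # Smallest power-of-two block in [64, 1024] with vocab_size <= block,
    # matching the C++ ladder (vocab_size > 512 -> 1024, ..., else 64).
    for dim in (64, 128, 256, 512):
        if vocab_size <= dim:
            return dim
    return 1024

assert get_block_size(100) == 128 and get_block_size(4000) == 1024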
__global__ void set_sorted_num(int* need_sorted_num, int bs) {
*need_sorted_num = bs;
}
template <typename T>
__global__ void print_kernel(T* input, int size) {
printf("[");
for (int i = 0; i < size; i++) {
if (i != size - 1) {
printf("%f, ", static_cast<float>(input[i]));
} else {
printf("%f]\n", static_cast<float>(input[i]));
}
}
}
template <typename T, typename Context>
void TopPSamplingKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& ps,
int random_seed,
DenseTensor* out,
DenseTensor* ids) {
typedef DataTypeTraits<T> traits_;
typedef typename traits_::DataType DataType_;
auto cu_stream = dev_ctx.stream();
const auto* input = &x;
// get the input dims
const auto& in_dims = input->dims();
int p_num = ps.numel();
int bs = in_dims[0];
int vocab_size = in_dims[1];
T* out_ptr = dev_ctx.template Alloc<T>(out);
int64_t* ids_ptr = dev_ctx.template Alloc<int64_t>(ids);
DenseTensor ps_now;
ps_now.Resize(phi::make_ddim({bs, 1}));
dev_ctx.template Alloc<T>(&ps_now);
phi::Copy(dev_ctx, ps, dev_ctx.GetPlace(), false, &ps_now);
DenseTensor inds_input;
inds_input.Resize(phi::make_ddim({bs, vocab_size}));
dev_ctx.template Alloc<int64_t>(&inds_input);
DenseTensor sorted_out;
sorted_out.Resize(phi::make_ddim({bs, vocab_size}));
dev_ctx.template Alloc<T>(&sorted_out);
DenseTensor sorted_id;
sorted_id.Resize(phi::make_ddim({bs, vocab_size}));
dev_ctx.template Alloc<int64_t>(&sorted_id);
int BlockSize = GetBlockSize(vocab_size);
switch (BlockSize) {
FIXED_BLOCK_DIM(FillIndex<int64_t><<<bs, kBlockDim, 0, cu_stream>>>(
inds_input.data<int64_t>(), bs, vocab_size));
default:
PD_THROW("the input data shape has error in the FillIndex kernel.");
}
curandState_t* dev_curand_states;
phi::Allocator::AllocationPtr curand_states_buf{nullptr};
curand_states_buf = phi::memory_utils::Alloc(
dev_ctx.GetPlace(),
bs * sizeof(curandState_t),
phi::Stream(reinterpret_cast<phi::StreamId>(dev_ctx.stream())));
dev_curand_states =
reinterpret_cast<curandState_t*>(curand_states_buf->ptr());
if (random_seed == -1) {
srand((unsigned int)(time(NULL)));
setup_kernel<<<1, 256, 0, cu_stream>>>(dev_curand_states, rand(), bs);
} else {
setup_kernel<<<1, 256, 0, cu_stream>>>(dev_curand_states, random_seed, bs);
}
DenseTensor count_iter;
count_iter.Resize(phi::make_ddim({bs + 1}));
dev_ctx.template Alloc<int>(&count_iter);
DenseTensor count_iter_begin;
count_iter_begin.Resize(phi::make_ddim({bs}));
dev_ctx.template Alloc<int>(&count_iter_begin);
SetCountIter<<<1, 256, 0, cu_stream>>>(count_iter.data<int>(), bs + 1);
constexpr int TopKMaxLength = 2;
constexpr int TopPBeamTopK = 10;
switch (BlockSize) {
FIXED_BLOCK_DIM(
KeMatrixTopPBeamTopK<T, TopKMaxLength, TopPBeamTopK, kBlockDim>
<<<bs, kBlockDim, 0, cu_stream>>>(x.data<T>(),
ps_now.data<T>(),
ids_ptr,
out_ptr,
vocab_size,
dev_curand_states,
count_iter.data<int>(),
count_iter_begin.data<int>()));
default:
PD_THROW("the input data shape has error in the topp_beam_topk kernel.");
}
size_t temp_storage_bytes = 0;
cub::TransformInputIterator<int, SegmentOffsetIter, int*>
segment_offsets_t_begin(count_iter_begin.data<int>(),
SegmentOffsetIter(vocab_size));
cub::TransformInputIterator<int, SegmentOffsetIter, int*>
segment_offsets_t_end(count_iter.data<int>(),
SegmentOffsetIter(vocab_size));
cub::DeviceSegmentedRadixSort::SortPairsDescending(
nullptr,
temp_storage_bytes,
reinterpret_cast<DataType_*>(const_cast<T*>(x.data<T>())),
reinterpret_cast<DataType_*>(const_cast<T*>(sorted_out.data<T>())),
inds_input.data<int64_t>(),
sorted_id.data<int64_t>(),
vocab_size * bs,
bs,
segment_offsets_t_begin,
segment_offsets_t_end + 1,
0,
sizeof(T) * 8,
cu_stream);
temp_storage_bytes = div_up(temp_storage_bytes, 256) * 256;
int64_t temp_size = temp_storage_bytes;
DenseTensor temp_storage;
temp_storage.Resize(phi::make_ddim({temp_size}));
dev_ctx.template Alloc<uint8_t>(&temp_storage);
cub::DeviceSegmentedRadixSort::SortPairsDescending(
temp_storage.data<uint8_t>(),
temp_storage_bytes,
reinterpret_cast<DataType_*>(const_cast<T*>(x.data<T>())),
reinterpret_cast<DataType_*>(const_cast<T*>(sorted_out.data<T>())),
inds_input.data<int64_t>(),
sorted_id.data<int64_t>(),
vocab_size * bs,
bs,
segment_offsets_t_begin,
segment_offsets_t_end + 1,
0,
sizeof(T) * 8,
cu_stream);
switch (BlockSize) {
FIXED_BLOCK_DIM(
topp_sampling<T, kBlockDim>
<<<bs, kBlockDim, 0, cu_stream>>>(sorted_out.data<T>(),
sorted_id.data<int64_t>(),
out_ptr,
ids_ptr,
ps_now.data<T>(),
p_num,
vocab_size,
count_iter.data<int>(),
count_iter_begin.data<int>()));
default:
PD_THROW("the input data shape has error in the topp_sampling kernel.");
}
return;
}
} // namespace phi
PD_REGISTER_KERNEL(top_p_sampling,
GPU,
ALL_LAYOUT,
phi::TopPSamplingKernel,
float,
phi::dtype::float16) {}
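The two cub::DeviceSegmentedRadixSort::SortPairsDescending calls above follow the usual CUB two-phase pattern (the first call, with a null buffer, only sizes the temp storage), and the segment offset iterators turn rows already resolved by the fast path into empty segments, so they are skipped. The net effect per remaining row is a descending sort of probabilities with their token ids, as in this NumPy sketch:

import numpy as np

def segmented_sort_desc(x, ids):
    # Per-row descending sort of probs with their token ids, the effect of
    # the segmented radix sort (empty segments are simply left untouched).
    order = np.argsort(-x, axis=1)
    return np.take_along_axis(x, order, 1), np.take_along_axis(ids, order, 1)

x = np.array([[0.1, 0.7, 0.2], [0.3, 0.3, 0.4]])
ids = np.tile(np.arange(3, dtype=np.int64), (2, 1))
print(segmented_sort_desc(x, ids))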
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/core/dense_tensor.h"
namespace phi {
template <typename T, typename Context>
void TopPSamplingKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& ps,
int random_seed,
DenseTensor* out,
DenseTensor* ids);
} // namespace phi
@@ -320,6 +320,7 @@ from .tensor.search import nonzero # noqa: F401
from .tensor.search import sort # noqa: F401
from .tensor.search import kthvalue # noqa: F401
from .tensor.search import mode # noqa: F401
from .tensor.search import top_p_sampling # noqa: F401
from .tensor.to_string import set_printoptions # noqa: F401
@@ -542,6 +543,7 @@ __all__ = [ # noqa
    'zeros_like',
    'maximum',
    'topk',
'top_p_sampling',
    'index_select',
    'CPUPlace',
    'matmul',
...
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
import paddle
from paddle.fluid import core
def TopPProcess(probs, top_p):
sorted_probs = paddle.sort(probs, descending=True)
sorted_indices = paddle.argsort(probs, descending=True)
cumulative_probs = paddle.cumsum(sorted_probs, axis=-1)
    # Remove tokens with a cumulative probability above top_p, but keep at
    # least the first (highest-probability) token
sorted_indices_to_remove = cumulative_probs > top_p
# Keep the first token
sorted_indices_to_remove = paddle.cast(
sorted_indices_to_remove, dtype='int64'
)
sorted_indices_to_remove[:, 1:] = sorted_indices_to_remove[:, :-1].clone()
sorted_indices_to_remove[:, 0] = 0
# Scatter sorted tensors to original indexing
sorted_indices = (
sorted_indices
+ paddle.arange(probs.shape[0]).unsqueeze(-1) * probs.shape[-1]
)
condition = paddle.scatter(
sorted_indices_to_remove.flatten(),
sorted_indices.flatten(),
sorted_indices_to_remove.flatten(),
)
condition = paddle.cast(condition, 'bool').reshape(probs.shape)
probs = paddle.where(condition, paddle.full_like(probs, 0.0), probs)
next_tokens = paddle.multinomial(probs)
next_scores = paddle.index_sample(probs, next_tokens)
return next_scores, next_tokens
class TestTopPAPI(unittest.TestCase):
def setUp(self):
self.topp = 0.0
self.seed = 6688
self.batch_size = 3
self.vocab_size = 10000
self.dtype = "float32"
self.input_data = np.random.rand(self.batch_size, self.vocab_size)
def run_dygraph(self, place):
with paddle.fluid.dygraph.guard(place):
input_tensor = paddle.to_tensor(self.input_data, self.dtype)
topp_tensor = paddle.to_tensor(
[
self.topp,
]
* self.batch_size,
self.dtype,
).reshape((-1, 1))
# test case for basic test case 1
paddle_result = paddle.top_p_sampling(
input_tensor, topp_tensor, self.seed
)
ref_res = TopPProcess(input_tensor, self.topp)
np.testing.assert_allclose(
paddle_result[0].numpy(), ref_res[0].numpy(), rtol=1e-05
)
np.testing.assert_allclose(
paddle_result[1].numpy().flatten(),
ref_res[1].numpy().flatten(),
rtol=0,
)
def run_static(self, place):
paddle.enable_static()
with paddle.static.program_guard(
paddle.static.Program(), paddle.static.Program()
):
input_tensor = paddle.static.data(
name="x", shape=[6, 1030], dtype=self.dtype
)
topp_tensor = paddle.static.data(
name="topp", shape=[6, 1], dtype=self.dtype
)
result = paddle.top_p_sampling(input_tensor, topp_tensor, self.seed)
ref_res = TopPProcess(input_tensor, self.topp)
exe = paddle.static.Executor(place)
input_data = np.random.rand(6, 1030).astype(self.dtype)
paddle_result = exe.run(
feed={
"x": input_data,
"topp": np.array(
[
self.topp,
]
* 6
).astype(self.dtype),
},
fetch_list=[
result[0],
result[1],
ref_res[0],
ref_res[1],
],
)
np.testing.assert_allclose(
paddle_result[0], paddle_result[2], rtol=1e-05
)
np.testing.assert_allclose(
paddle_result[1], paddle_result[3], rtol=1e-05
)
def test_cases(self):
places = [core.CUDAPlace(0)]
for place in places:
self.run_dygraph(place)
self.run_static(place)
if __name__ == "__main__":
unittest.main()
@@ -278,6 +278,7 @@ from .search import index_sample # noqa: F401
from .search import masked_select # noqa: F401
from .search import kthvalue # noqa: F401
from .search import mode # noqa: F401
from .search import top_p_sampling
from .stat import mean # noqa: F401
from .stat import std # noqa: F401
@@ -468,6 +469,7 @@ tensor_method_func = [ # noqa
    'argsort',
    'masked_select',
    'topk',
'top_p_sampling',
    'where',
    'index_select',
    'nonzero',
...
@@ -1129,3 +1129,38 @@ def kthvalue(x, k, axis=None, keepdim=False, name=None):
    )
    indices.stop_gradient = True
    return values, indices
def top_p_sampling(x, ps, seed=None, name=None):
"""
    Get top-p sampled scores and token ids.

    Args:
        x (Tensor): A 2-D Tensor of probabilities with shape [batch_size, vocab_size], of type float32 or float16.
        ps (Tensor): A Tensor of per-row top-p thresholds with shape [batch_size, 1], with the same type as x.
        seed (int, optional): The random seed. Default: None, in which case a random host-side seed is used.
        name (str, optional): For details, please refer to :ref:`api_guide_Name`. Generally, no setting is required. Default: None.

    Returns:
        tuple(Tensor), the sampled scores and token ids, each with shape [batch_size, 1]. The scores have the same data type as the input x; the ids are int64.
"""
if seed is None:
seed = -1
if in_dygraph_mode():
return _C_ops.top_p_sampling(x, ps, seed)
inputs = {"x": [x], "ps": [ps]}
    attrs = {"random_seed": seed}  # attr name matches the op definition (random_seed)
helper = LayerHelper('top_p_sampling', **locals())
out = helper.create_variable_for_type_inference(dtype=x.dtype)
ids = helper.create_variable_for_type_inference(dtype="int64")
helper.append_op(
type='top_p_sampling',
inputs=inputs,
outputs={'out': [out], 'ids': [ids]},
attrs=attrs,
)
return out, ids
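A minimal dygraph usage sketch (assumes a CUDA build of Paddle, since only a GPU kernel is registered):

import paddle

paddle.device.set_device("gpu")
probs = paddle.nn.functional.softmax(
    paddle.rand([2, 1000], dtype="float32"), axis=-1
)
ps = paddle.full([2, 1], 0.9, dtype="float32")  # per-row top-p threshold
scores, ids = paddle.top_p_sampling(probs, ps, seed=6688)
print(scores.shape, ids.shape)  # [2, 1] [2, 1]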