Commit e1545af4 (unverified)
Authored July 18, 2023 by lzy; committed by GitHub on July 18, 2023

make top_p_sampling supports threshold (#55486)

* make top_p_sampling supports threshold
* delete __nv_bfloat16

Parent: 0252287e

Showing 7 changed files with 110 additions and 69 deletions (+110 −69):
paddle/phi/api/yaml/ops.yaml (+2 −1)
paddle/phi/infermeta/binary.cc (+1 −0)
paddle/phi/infermeta/binary.h (+1 −0)
paddle/phi/kernels/gpu/top_p_sampling_kernel.cu (+92 −61)
paddle/phi/kernels/top_p_sampling_kernel.h (+1 −0)
python/paddle/fluid/tests/unittests/test_top_p_sampling.py (+7 −2)
python/paddle/tensor/search.py (+6 −5)
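The change threads an optional per-batch `threshold` tensor through the whole `top_p_sampling` stack (op definition, infer-meta, GPU kernel, Python API): when the token chosen by top-p sampling has a probability below the threshold, the kernels fall back to a higher-probability candidate instead. For orientation, here is a rough NumPy sketch of that rule for a single row; the function and variable names are illustrative, and the two CUDA code paths below implement slightly different fallbacks (backtracking in the beam top-k path, a random above-threshold pick in the full-sort path):

import numpy as np

def top_p_sample_with_threshold(probs, top_p, threshold=0.0, rng=None):
    # Sort the vocabulary by descending probability.
    rng = rng or np.random.default_rng()
    order = np.argsort(-probs)
    sorted_p = probs[order]
    # The kernel draws rand_top_p = curand_uniform(...) * top_p and stops at
    # the first prefix whose cumulative mass reaches it.
    cutoff = rng.random() * top_p
    stop = int(np.searchsorted(np.cumsum(sorted_p), cutoff))
    if sorted_p[stop] >= threshold:
        pick = stop
    else:
        # Don't sample a low-score token: back off to the nearest earlier
        # candidate at or above the threshold (index 0 if none qualifies).
        above = np.nonzero(sorted_p[:stop] >= threshold)[0]
        pick = int(above[-1]) if above.size else 0
    return order[pick], sorted_p[pick]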
paddle/phi/api/yaml/ops.yaml

@@ -1922,13 +1922,14 @@
     backward : thresholded_relu_grad
 
 - op : top_p_sampling
-  args : (Tensor x, Tensor ps, int random_seed=-1)
+  args : (Tensor x, Tensor ps, Tensor threshold, int random_seed=-1)
   output : Tensor (out), Tensor(ids)
   infer_meta :
     func : TopPSamplingInferMeta
   kernel :
     func : top_p_sampling
     data_type : x
+  optional : threshold
 
 - op : topk
   args : (Tensor x, Scalar(int) k = 1, int axis = -1, bool largest = true, bool sorted = true)
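Marking the new input with `optional : threshold` keeps the op backward compatible: callers that pass no threshold still work. Downstream this surfaces as `const paddle::optional<DenseTensor>&` in the kernel signatures and as null-guarded reads (`threshold ? static_cast<float>(threshold[bid]) : 0.f`) inside the CUDA kernels.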
paddle/phi/infermeta/binary.cc

@@ -2744,6 +2744,7 @@ void TriangularSolveInferMeta(const MetaTensor& x,
 void TopPSamplingInferMeta(const MetaTensor& x,
                            const MetaTensor& ps,
+                           const MetaTensor& threshold,
                            int random_seed,
                            MetaTensor* out,
                            MetaTensor* ids) {
paddle/phi/infermeta/binary.h

@@ -430,6 +430,7 @@ void TriangularSolveInferMeta(const MetaTensor& x,
 void TopPSamplingInferMeta(const MetaTensor& x,
                            const MetaTensor& ps,
+                           const MetaTensor& threshold,
                            int random_seed,
                            MetaTensor* out,
                            MetaTensor* ids);
paddle/phi/kernels/gpu/top_p_sampling_kernel.cu

@@ -41,11 +41,6 @@ struct DataTypeTraits<phi::dtype::float16> {
   using DataType = half;
 };
 
-// template <>
-// struct DataTypeTraits<phi::dtype::bfloat16> {
-//   using DataType = __nv_bfloat16;
-// };
-
 #define FINAL_MASK 0xFFFFFFFF
 
 #define FIXED_BLOCK_DIM_BASE(dim, ...) \
@@ -119,7 +114,7 @@ __global__ void setup_kernel(curandState_t* state,
                              const int bs) {
   int idx = blockIdx.x * blockDim.x + threadIdx.x;
   for (int i = idx; i < bs; i += gridDim.x * blockDim.x) {
-    curand_init(seed + i, 0, 0, &state[i]);
+    curand_init(seed, i, 0, &state[i]);
   }
 }
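The seeding change is worth noting: `curand_init(seed + i, 0, 0, ...)` gave every state a different seed, whereas `curand_init(seed, i, 0, ...)` uses one seed with a distinct subsequence per state, which is cuRAND's recommended way to obtain statistically independent streams.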
@@ -278,6 +273,7 @@ __device__ __forceinline__ void BlockReduce(Pair<T> shared_max[],
 
 template <typename T, int MaxLength, int TopPBeamTopK, int BlockSize>
 __global__ void KeMatrixTopPBeamTopK(const T* src,
+                                     const T* threshold,
                                      T* top_ps,
                                      int64_t* out_id,  // topk id
                                      T* out_val,       // topk val

@@ -289,6 +285,8 @@ __global__ void KeMatrixTopPBeamTopK(const T* src,
   const int wid = tid / 32;
   const int lane = tid % 32;
   const int bid = blockIdx.x;
+  const float threshold_now =
+      threshold ? static_cast<float>(threshold[bid]) : 0.f;
   int top_num = TopPBeamTopK;
   float top_p_num = static_cast<float>(top_ps[bid]);

@@ -329,8 +327,10 @@ __global__ void KeMatrixTopPBeamTopK(const T* src,
     float rand_top_p = curand_uniform(state + bid) * top_p_num;
     top_ps[bid] = (T)rand_top_p;
     float sum_prob = 0.0f;
     for (int i = 0; i < TopPBeamTopK; i++) {
-      sum_prob += static_cast<float>(beam_max[i].v);
+      float val = static_cast<float>(beam_max[i].v);
+      sum_prob += val;
 #ifdef DEBUG_TOPP
       printf("bi: %d, top_p: %f, rand_top_p: %f, sum_prob: %f\n",
              bid,

@@ -340,12 +340,21 @@ __global__ void KeMatrixTopPBeamTopK(const T* src,
 #endif
       if (sum_prob >= rand_top_p) {
         count_iter_begin[bid] += 1;
-        out_id[bid] = (int64_t)beam_max[i].id;
-        out_val[bid] = beam_max[i].v;
+        if (val < threshold_now) {
+          // don't sample low score token
+          int start_id = i == 0 ? 0 : i - 1;
+          for (int j = start_id; j >= 0; j--) {
+            float val_now = static_cast<float>(beam_max[j].v);
+            if (val_now >= threshold_now || j == 0) {
+              out_id[bid] = static_cast<int64_t>(beam_max[j].id);
+              out_val[bid] = beam_max[j].v;
+              break;
+            }
+          }
+        } else {
+          out_id[bid] = static_cast<int64_t>(beam_max[i].id);
+          out_val[bid] = beam_max[i].v;
+        }
 #ifdef DEBUG_TOPP
         printf("bi: %d, early stop id: %d\n",
                bid,
                static_cast<int>(out_id[bid]));
 #endif
         break;
       }
     }
   }
 }
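The early-stop branch keeps its behavior when the chosen candidate clears the threshold, but when `val < threshold_now` it walks back through the already-sorted beam candidates (`beam_max[j]`, from `i - 1` down to 0) and emits the nearest one at or above the threshold, defaulting to the top candidate at `j == 0`.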
@@ -374,11 +383,14 @@ __global__ void FillIndex(T* indices, T num_rows, T num_cols) {
 }
 
 struct BlockPrefixCallbackOp {
+  // Running prefix
   float running_total;
+  // Constructor
   __device__ BlockPrefixCallbackOp(float running_total)
       : running_total(running_total) {}
+  // Callback operator to be entered by the first warp of threads in the block.
+  // Thread-0 is responsible for returning a value for seeding the block-wide
+  // scan.
   __device__ float operator()(float block_aggregate) {
     float old_prefix = running_total;
     running_total += block_aggregate;

@@ -386,14 +398,28 @@ struct BlockPrefixCallbackOp {
   }
 };
 
+template <typename T>
+__device__ T max_func(const T a, const T b) {
+  return a > b ? a : b;
+}
+
+template <typename T>
+struct MaxOp {
+  __device__ __forceinline__ T operator()(const T& a, const T& b) const {
+    return max_func(a, b);
+  }
+};
+
 template <typename T, int BLOCK_SIZE>
 __global__ void topp_sampling(T* sorted_probs,
                               int64_t* sorted_id,
                               T* out_val,
                               int64_t* out_id,
                               const T* top_ps,
-                              int p_num,
-                              int vocab_size,
+                              const T* threshold,
+                              const uint64_t seed,
+                              const int p_num,
+                              const int vocab_size,
                               int* count_iter,
                               int* count_iter_begin) {
   __shared__ int stop_shared;

@@ -404,6 +430,8 @@ __global__ void topp_sampling(T* sorted_probs,
   const int lane_id = tid % 32;
   const int warp_id = tid / 32;
   const float p_t = static_cast<float>(top_ps[bid]);
+  const float threshold_now =
+      threshold ? static_cast<float>(threshold[bid]) : 0.f;
   if (tid == 0) {
     stop_shared = 0;
     rand_p = p_t;

@@ -417,8 +445,11 @@ __global__ void topp_sampling(T* sorted_probs,
 
   typedef cub::BlockScan<float, BLOCK_SIZE> BlockScan;
+  typedef cub::BlockReduce<int, BLOCK_SIZE> BlockReduce;
   __shared__ typename BlockScan::TempStorage temp_storage;
+  __shared__ typename BlockReduce::TempStorage temp_storage_reduce;
   __shared__ uint32_t selected_shared[NUM_WARPS];
+  int threshold_id = 0;
 
   // Initialize running total
   BlockPrefixCallbackOp prefix_op(0);
@@ -429,23 +460,15 @@ __global__ void topp_sampling(T* sorted_probs,
   __syncthreads();
   int offset = bid * vocab_size;
-#ifdef DEBUG_TOPP
-  if (tid == 0) {
-    printf(
-        "first_elem1_1: %f, first_elem1_2: %f, first_id1_1: %d, first_id1_2: "
-        "%d\n",
-        static_cast<float>(sorted_probs[offset]),
-        static_cast<float>(sorted_probs[offset + 1]),
-        static_cast<int>(sorted_id[offset]),
-        static_cast<int>(sorted_id[offset + 1]));
-  }
-#endif
   int end = ((vocab_size + BLOCK_SIZE - 1) / BLOCK_SIZE) * BLOCK_SIZE;
   int i_activate = 0;
   float thread_offset = 0;
   for (int i = tid; i < end; i += BLOCK_SIZE) {
     float thread_count =
         (i < vocab_size) ? static_cast<float>(sorted_probs[offset + i]) : 0.f;
+    if (i < vocab_size && thread_count >= threshold_now) {
+      threshold_id = i;
+    }
     BlockScan(temp_storage)
         .InclusiveSum(thread_count, thread_offset, prefix_op);
@@ -466,32 +489,15 @@ __global__ void topp_sampling(T* sorted_probs,
   __syncthreads();
   if (stop_shared == 0) {
     if (tid == 0) {
-      out_id[bid] = sorted_id[offset + vocab_size - 1];
-      out_val[bid] = sorted_probs[offset + vocab_size - 1];
-#ifdef DEBUG_TOPP
-      printf("stop_shared: %d, out_id: %d, out_val: %f\n",
-             static_cast<int>(stop_shared),
-             static_cast<int>(out_id[bid]),
-             static_cast<float>(out_val[bid]));
-#endif
+      out_id[bid] = sorted_id[offset];
+      out_val[bid] = sorted_probs[offset];
     }
     return;
   }
-#ifdef DEBUG_TOPP
-  if (tid == 0) {
-    printf(
-        "first_elem2_1: %f, first_elem2_2: %f, first_id2_1: %d, first_id2_2: "
-        "%d\n",
-        static_cast<float>(sorted_probs[offset]),
-        static_cast<float>(sorted_probs[offset + 1]),
-        static_cast<int>(sorted_id[offset]),
-        static_cast<int>(sorted_id[offset + 1]));
-  }
-#endif
   bool skip = (selected_shared[warp_id] > 0) ? false : true;
   for (int i = 0; i < warp_id; i++) {
     if (selected_shared[i] != 0) {
       // If the previous has stopped, skip the current warp
       skip = true;
     }
   }

@@ -499,19 +505,22 @@ __global__ void topp_sampling(T* sorted_probs,
       int active_lane_id =
           WARP_SIZE - __popc(selected_shared[warp_id]);  // first not 0
       if (lane_id == active_lane_id) {
-#ifdef DEBUG_TOPP
-        printf(
-            "active_lane_id: %d, i_activate: %d.\n", active_lane_id, i_activate);
-        for (int i = 0; i < active_lane_id; i++) {
-          printf("p %d, value: %f\n",
-                 i,
-                 static_cast<float>(sorted_probs[offset + i]));
-        }
-#endif
-        out_id[bid] = sorted_id[offset + i_activate];
-        out_val[bid] = sorted_probs[offset + i_activate];
+        float val = static_cast<float>(sorted_probs[offset + i_activate]);
+        if (val < threshold_now) {
+          // don't sample low score token
+          int max_id =
+              BlockReduce(temp_storage_reduce).Reduce(threshold_id, MaxOp<int>());
+          curandStatePhilox4_32_10_t rng;
+          curand_init(seed, tid, 0, &rng);
+          int random_id = curand(&rng) % (max_id + 1);
+          out_id[bid] = sorted_id[offset + random_id];
+          out_val[bid] = sorted_probs[offset + random_id];
+        } else {
+          out_id[bid] = sorted_id[offset + i_activate];
+          out_val[bid] = sorted_probs[offset + i_activate];
+        }
       }
     }
   }
 }
 
 int GetBlockSize(int vocab_size) {
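The full-sort path handles a below-threshold pick differently from the beam path: `threshold_id` records, per thread, the last sorted position whose probability clears the threshold, `BlockReduce(...).Reduce(threshold_id, MaxOp<int>())` turns that into the block-wide last qualifying position, and the kernel then draws a uniformly random candidate from that qualifying prefix with an inline Philox generator. A hedged Python rendering of the fallback (names illustrative):

import numpy as np

def fallback_pick(sorted_probs, threshold, rng):
    # Last descending-sorted index whose probability clears the threshold;
    # this plays the role of the block-reduced max of threshold_id.
    above = np.nonzero(sorted_probs >= threshold)[0]
    max_id = int(above[-1]) if above.size else 0
    # Uniform pick in [0, max_id], mirroring curand(&rng) % (max_id + 1).
    return int(rng.integers(0, max_id + 1))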
@@ -544,10 +553,26 @@ __global__ void print_kernel(T* input, int size) {
   }
 }
 
+template <typename T>
+T* SafeGetTensorPtr(const DenseTensor& t) {
+  return const_cast<T*>(t.data<T>());
+}
+
+template <typename T>
+T* SafeGetTensorPtr(const DenseTensor* t) {
+  return t ? SafeGetTensorPtr<T>(*t) : nullptr;
+}
+
+template <typename T>
+T* SafeGetTensorPtr(const paddle::optional<DenseTensor>& t) {
+  return t ? SafeGetTensorPtr<T>(t.get()) : nullptr;
+}
+
 template <typename T, typename Context>
 void TopPSamplingKernel(const Context& dev_ctx,
                         const DenseTensor& x,
                         const DenseTensor& ps,
+                        const paddle::optional<DenseTensor>& threshold,
                         int random_seed,
                         DenseTensor* out,
                         DenseTensor* ids) {

@@ -597,11 +622,12 @@ void TopPSamplingKernel(const Context& dev_ctx,
             phi::Stream(reinterpret_cast<phi::StreamId>(dev_ctx.stream())));
     dev_curand_states =
         reinterpret_cast<curandState_t*>(curand_states_buf->ptr());
+    unsigned int seed = 0;
     if (random_seed == -1) {
-      srand((unsigned int)(time(NULL)));
-      setup_kernel<<<1, 256, 0, cu_stream>>>(dev_curand_states, rand(), bs);
+      rand_r(&seed);
+      setup_kernel<<<1, 256, 0, cu_stream>>>(dev_curand_states, seed, bs);
     } else {
-      setup_kernel<<<1, 256, 0, cu_stream>>>(dev_curand_states, random_seed, bs);
+      seed = random_seed;
     }
   }
   DenseTensor count_iter;
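Two things change in seed handling here: the host-side `srand(time(NULL))` / `rand()` pair is replaced by a single `rand_r(&seed)` call, and a user-supplied `random_seed` is now just recorded in `seed`, which later flows into `topp_sampling`, where the Philox state is initialized inline with `curand_init(seed, tid, 0, &rng)` instead of through a second `setup_kernel` launch.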
@@ -612,12 +638,15 @@ void TopPSamplingKernel(const Context& dev_ctx,
   dev_ctx.template Alloc<int>(&count_iter_begin);
   SetCountIter<<<1, 256, 0, cu_stream>>>(count_iter.data<int>(), bs + 1);
+  T* threshold_data = SafeGetTensorPtr<T>(threshold);
 
   constexpr int TopKMaxLength = 2;
   constexpr int TopPBeamTopK = 10;
   switch (BlockSize) {
     FIXED_BLOCK_DIM(
         KeMatrixTopPBeamTopK<T, TopKMaxLength, TopPBeamTopK, kBlockDim>
         <<<bs, kBlockDim, 0, cu_stream>>>(x.data<T>(),
+                                          threshold_data,
                                           ps_now.data<T>(),
                                           ids_ptr,
                                           out_ptr,

@@ -682,6 +711,8 @@ void TopPSamplingKernel(const Context& dev_ctx,
                                           out_ptr,
                                           ids_ptr,
                                           ps_now.data<T>(),
+                                          threshold_data,
+                                          seed,
                                           p_num,
                                           vocab_size,
                                           count_iter.data<int>(),
paddle/phi/kernels/top_p_sampling_kernel.h

@@ -22,6 +22,7 @@ template <typename T, typename Context>
 void TopPSamplingKernel(const Context& dev_ctx,
                         const DenseTensor& x,
                         const DenseTensor& ps,
+                        const paddle::optional<DenseTensor>& threshold,
                         int random_seed,
                         DenseTensor* out,
                         DenseTensor* ids);
python/paddle/fluid/tests/unittests/test_top_p_sampling.py

@@ -53,6 +53,9 @@ def TopPProcess(probs, top_p):
     return next_scores, next_tokens
 
 
+@unittest.skipIf(
+    not core.is_compiled_with_cuda(), "core is not compiled with CUDA "
+)
 class TestTopPAPI(unittest.TestCase):
     def setUp(self):
         self.topp = 0.0

@@ -74,7 +77,7 @@ class TestTopPAPI(unittest.TestCase):
             ).reshape((-1, 1))
 
             # test case for basic test case 1
             paddle_result = paddle.top_p_sampling(
-                input_tensor, topp_tensor, self.seed
+                input_tensor, topp_tensor, seed=self.seed
             )
             ref_res = TopPProcess(input_tensor, self.topp)

@@ -98,7 +101,9 @@ class TestTopPAPI(unittest.TestCase):
                 topp_tensor = paddle.static.data(
                     name="topp", shape=[6, 1], dtype=self.dtype
                 )
-                result = paddle.top_p_sampling(input_tensor, topp_tensor, self.seed)
+                result = paddle.top_p_sampling(
+                    input_tensor, topp_tensor, seed=self.seed
+                )
                 ref_res = TopPProcess(input_tensor, self.topp)
                 exe = paddle.static.Executor(place)
                 input_data = np.random.rand(6, 1030).astype(self.dtype)
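Note the call-site adjustment: since `threshold` now sits between `ps` and `seed` in the Python signature, the tests pass `seed=self.seed` by keyword instead of positionally, which would otherwise bind `self.seed` to `threshold`.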
python/paddle/tensor/search.py

@@ -1131,13 +1131,14 @@ def kthvalue(x, k, axis=None, keepdim=False, name=None):
     return values, indices
 
 
-def top_p_sampling(x, ps, seed=None, name=None):
+def top_p_sampling(x, ps, threshold=None, seed=None, name=None):
     """
     Get the TopP scores and ids.
 
     Args:
         x(Tensor): A N-D Tensor with type float32, float16 and bfloat16.
         ps(Tensor): A 1-D Tensor with type float32, float16 and bfloat16.
+        threshold(Tensor): A 1-D Tensor with type float32, float16 and bfloat16.
         seed(int, optional): the random seed,
         name (str, optional): For details, please refer to :ref:`api_guide_Name`. Generally, no setting is required. Default: None.

@@ -1149,10 +1150,10 @@ def top_p_sampling(x, ps, seed=None, name=None):
         seed = -1
 
     if in_dygraph_mode():
-        return _C_ops.top_p_sampling(x, ps, seed)
+        return _C_ops.top_p_sampling(x, ps, threshold, seed)
 
-    inputs = {"x": [x], "ps": [ps]}
-    attrs = {"seed": seed}
+    inputs = {"x": x, "ps": ps, "threshold": threshold}
+    attrs = {"random_seed": seed}
 
     helper = LayerHelper('top_p_sampling', **locals())
     out = helper.create_variable_for_type_inference(dtype=x.dtype)

@@ -1160,7 +1161,7 @@ def top_p_sampling(x, ps, seed=None, name=None):
     helper.append_op(
         type='top_p_sampling',
         inputs=inputs,
-        outputs={'out': [out], 'ids': [ids]},
+        outputs={'out': out, 'ids': ids},
         attrs=attrs,
     )
     return out, ids
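A minimal usage sketch of the extended Python API; the shapes follow the unit test above and the values are illustrative (the op only has a GPU kernel, so this assumes a CUDA build of Paddle):

import paddle

paddle.set_device("gpu")
x = paddle.rand([6, 1030])                                # [batch, vocab]
probs = paddle.nn.functional.softmax(x, axis=-1)
ps = paddle.full([6, 1], 0.8, dtype=probs.dtype)          # per-row top-p
threshold = paddle.full([6, 1], 1e-4, dtype=probs.dtype)  # per-row score floor

out, ids = paddle.top_p_sampling(probs, ps, threshold=threshold, seed=42)
print(out.shape, ids.shape)  # sampled scores and token ids, one per row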