未验证 提交 f9c9dc29 编写于 作者: L lzy 提交者: GitHub

add top_p_sampling (#54127)

上级 435560f0
......@@ -1921,6 +1921,15 @@
func : thresholded_relu
backward : thresholded_relu_grad
- op : top_p_sampling
args : (Tensor x, Tensor ps, int random_seed=-1)
output : Tensor (out), Tensor(ids)
infer_meta :
func : TopPSamplingInferMeta
kernel :
func : top_p_sampling
data_type : x
- op : topk
args : (Tensor x, Scalar(int) k = 1, int axis = -1, bool largest = true, bool sorted = true)
output : Tensor(out), Tensor(indices)
......
......@@ -2742,6 +2742,26 @@ void TriangularSolveInferMeta(const MetaTensor& x,
out->share_lod(y);
}
void TopPSamplingInferMeta(const MetaTensor& x,
const MetaTensor& ps,
int random_seed,
MetaTensor* out,
MetaTensor* ids) {
auto x_dims = x.dims();
auto ps_dims = ps.dims();
PADDLE_ENFORCE_EQ(x_dims[0],
ps_dims[0],
phi::errors::InvalidArgument(
"The x_dims[0] must be equal to ps_dims[0] "
"But received x_dims[0] = %d and ps_dims[0] = %d.",
x_dims[0],
ps_dims[0]));
ids->set_dims(phi::make_ddim({x_dims[0], 1}));
ids->set_dtype(DataType::INT64);
out->set_dims(phi::make_ddim({x_dims[0], 1}));
out->set_dtype(x.dtype());
}
void LstsqInferMeta(const MetaTensor& x,
const MetaTensor& y,
const Scalar& rcond,
......
......@@ -428,6 +428,12 @@ void TriangularSolveInferMeta(const MetaTensor& x,
bool unitriangular,
MetaTensor* out);
void TopPSamplingInferMeta(const MetaTensor& x,
const MetaTensor& ps,
int random_seed,
MetaTensor* out,
MetaTensor* ids);
void LstsqInferMeta(const MetaTensor& x,
const MetaTensor& y,
const Scalar& rcond,
......
此差异已折叠。
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/core/dense_tensor.h"
namespace phi {
template <typename T, typename Context>
void TopPSamplingKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& ps,
int random_seed,
DenseTensor* out,
DenseTensor* ids);
} // namespace phi
......@@ -320,6 +320,7 @@ from .tensor.search import nonzero # noqa: F401
from .tensor.search import sort # noqa: F401
from .tensor.search import kthvalue # noqa: F401
from .tensor.search import mode # noqa: F401
from .tensor.search import top_p_sampling # noqa: F401
from .tensor.to_string import set_printoptions # noqa: F401
......@@ -542,6 +543,7 @@ __all__ = [ # noqa
'zeros_like',
'maximum',
'topk',
'top_p_sampling',
'index_select',
'CPUPlace',
'matmul',
......
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
import paddle
from paddle.fluid import core
def TopPProcess(probs, top_p):
sorted_probs = paddle.sort(probs, descending=True)
sorted_indices = paddle.argsort(probs, descending=True)
cumulative_probs = paddle.cumsum(sorted_probs, axis=-1)
# Remove tokens with cumulative probs above the top_p, But keep at
# least min_tokens_to_keep tokens
sorted_indices_to_remove = cumulative_probs > top_p
# Keep the first token
sorted_indices_to_remove = paddle.cast(
sorted_indices_to_remove, dtype='int64'
)
sorted_indices_to_remove[:, 1:] = sorted_indices_to_remove[:, :-1].clone()
sorted_indices_to_remove[:, 0] = 0
# Scatter sorted tensors to original indexing
sorted_indices = (
sorted_indices
+ paddle.arange(probs.shape[0]).unsqueeze(-1) * probs.shape[-1]
)
condition = paddle.scatter(
sorted_indices_to_remove.flatten(),
sorted_indices.flatten(),
sorted_indices_to_remove.flatten(),
)
condition = paddle.cast(condition, 'bool').reshape(probs.shape)
probs = paddle.where(condition, paddle.full_like(probs, 0.0), probs)
next_tokens = paddle.multinomial(probs)
next_scores = paddle.index_sample(probs, next_tokens)
return next_scores, next_tokens
class TestTopPAPI(unittest.TestCase):
def setUp(self):
self.topp = 0.0
self.seed = 6688
self.batch_size = 3
self.vocab_size = 10000
self.dtype = "float32"
self.input_data = np.random.rand(self.batch_size, self.vocab_size)
def run_dygraph(self, place):
with paddle.fluid.dygraph.guard(place):
input_tensor = paddle.to_tensor(self.input_data, self.dtype)
topp_tensor = paddle.to_tensor(
[
self.topp,
]
* self.batch_size,
self.dtype,
).reshape((-1, 1))
# test case for basic test case 1
paddle_result = paddle.top_p_sampling(
input_tensor, topp_tensor, self.seed
)
ref_res = TopPProcess(input_tensor, self.topp)
np.testing.assert_allclose(
paddle_result[0].numpy(), ref_res[0].numpy(), rtol=1e-05
)
np.testing.assert_allclose(
paddle_result[1].numpy().flatten(),
ref_res[1].numpy().flatten(),
rtol=0,
)
def run_static(self, place):
paddle.enable_static()
with paddle.static.program_guard(
paddle.static.Program(), paddle.static.Program()
):
input_tensor = paddle.static.data(
name="x", shape=[6, 1030], dtype=self.dtype
)
topp_tensor = paddle.static.data(
name="topp", shape=[6, 1], dtype=self.dtype
)
result = paddle.top_p_sampling(input_tensor, topp_tensor, self.seed)
ref_res = TopPProcess(input_tensor, self.topp)
exe = paddle.static.Executor(place)
input_data = np.random.rand(6, 1030).astype(self.dtype)
paddle_result = exe.run(
feed={
"x": input_data,
"topp": np.array(
[
self.topp,
]
* 6
).astype(self.dtype),
},
fetch_list=[
result[0],
result[1],
ref_res[0],
ref_res[1],
],
)
np.testing.assert_allclose(
paddle_result[0], paddle_result[2], rtol=1e-05
)
np.testing.assert_allclose(
paddle_result[1], paddle_result[3], rtol=1e-05
)
def test_cases(self):
places = [core.CUDAPlace(0)]
for place in places:
self.run_dygraph(place)
self.run_static(place)
if __name__ == "__main__":
unittest.main()
......@@ -278,6 +278,7 @@ from .search import index_sample # noqa: F401
from .search import masked_select # noqa: F401
from .search import kthvalue # noqa: F401
from .search import mode # noqa: F401
from .search import top_p_sampling
from .stat import mean # noqa: F401
from .stat import std # noqa: F401
......@@ -468,6 +469,7 @@ tensor_method_func = [ # noqa
'argsort',
'masked_select',
'topk',
'top_p_sampling',
'where',
'index_select',
'nonzero',
......
......@@ -1129,3 +1129,38 @@ def kthvalue(x, k, axis=None, keepdim=False, name=None):
)
indices.stop_gradient = True
return values, indices
def top_p_sampling(x, ps, seed=None, name=None):
"""
Get the TopP scores and ids.
Args:
x(Tensor): A N-D Tensor with type float32, float16 and bfloat16.
ps(Tensor): A 1-D Tensor with type float32, float16 and bfloat16.
seed(int, optional): the random seed,
name (str, optional): For details, please refer to :ref:`api_guide_Name`. Generally, no setting is required. Default: None.
Returns:
tuple(Tensor), return the values and indices. The value data type is the same as the input `x`. The indices data type is int64.
"""
if seed is None:
seed = -1
if in_dygraph_mode():
return _C_ops.top_p_sampling(x, ps, seed)
inputs = {"x": [x], "ps": [ps]}
attrs = {"seed": seed}
helper = LayerHelper('top_p_sampling', **locals())
out = helper.create_variable_for_type_inference(dtype=x.dtype)
ids = helper.create_variable_for_type_inference(dtype="int64")
helper.append_op(
type='top_p_sampling',
inputs=inputs,
outputs={'out': [out], 'ids': [ids]},
attrs=attrs,
)
return out, ids
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册