From 662230487cdaac10891d5e272175a81a4e234cb3 Mon Sep 17 00:00:00 2001 From: Yanxing Shi <48111042+Yanxing-Shi@users.noreply.github.com> Date: Mon, 13 Sep 2021 12:06:36 +0800 Subject: [PATCH] Add searchsorted op (#35159) * fix github name * fix CI error * fix review and CI error * fix inf,nan error and modify unittest samples * add unittest samples * add unittest samples * fix unittest error * test=document_fix * test=document_fix * modify doc and add unittest samples * fix error newline in constant * modify doc after mentor review * modify __all__ and doc * modify doc --- paddle/fluid/operators/math/algorithm.h | 10 +- paddle/fluid/operators/searchsorted_op.cc | 126 +++++++++++ paddle/fluid/operators/searchsorted_op.cu | 23 ++ paddle/fluid/operators/searchsorted_op.h | 170 +++++++++++++++ python/paddle/__init__.py | 2 + .../tests/unittests/test_searchsorted_op.py | 198 ++++++++++++++++++ python/paddle/tensor/__init__.py | 1 + python/paddle/tensor/search.py | 72 +++++++ 8 files changed, 597 insertions(+), 5 deletions(-) create mode 100644 paddle/fluid/operators/searchsorted_op.cc create mode 100644 paddle/fluid/operators/searchsorted_op.cu create mode 100644 paddle/fluid/operators/searchsorted_op.h create mode 100644 python/paddle/fluid/tests/unittests/test_searchsorted_op.py diff --git a/paddle/fluid/operators/math/algorithm.h b/paddle/fluid/operators/math/algorithm.h index 864cb94cec..346c693a22 100644 --- a/paddle/fluid/operators/math/algorithm.h +++ b/paddle/fluid/operators/math/algorithm.h @@ -1,4 +1,4 @@ -// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -39,8 +39,8 @@ HOSTDEVICE inline int64_t BinarySearch(const T *x, int64_t num, const T &val) { return -1; } -template -HOSTDEVICE inline size_t LowerBound(const T *x, size_t num, const T &val) { +template +HOSTDEVICE inline size_t LowerBound(const T1 *x, size_t num, const T2 &val) { #if defined(__CUDA_ARCH__) || defined(__HIPCC__) // @{ Group LowerBound // The following code is from // https://en.cppreference.com/w/cpp/algorithm/lower_bound @@ -62,8 +62,8 @@ HOSTDEVICE inline size_t LowerBound(const T *x, size_t num, const T &val) { #endif // @} End Group LowerBound } -template -HOSTDEVICE inline size_t UpperBound(const T *x, size_t num, const T &val) { +template +HOSTDEVICE inline size_t UpperBound(const T1 *x, size_t num, const T2 &val) { #if defined(__CUDA_ARCH__) || defined(__HIPCC__) // @{ Group UpperBound // The following code is from // https://en.cppreference.com/w/cpp/algorithm/upper_bound diff --git a/paddle/fluid/operators/searchsorted_op.cc b/paddle/fluid/operators/searchsorted_op.cc new file mode 100644 index 0000000000..bbd5b9c4e7 --- /dev/null +++ b/paddle/fluid/operators/searchsorted_op.cc @@ -0,0 +1,126 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/operators/searchsorted_op.h" + +#include "paddle/fluid/platform/enforce.h" + +namespace paddle { +namespace operators { + +class SearchSortedOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + static bool SearchsortedDimsMatchedBeforeLastDim( + const framework::DDim& sequences_dims, + const framework::DDim& values_dims) { + if (sequences_dims.size() != values_dims.size()) { + return false; + } + const auto& sequences_dims_size = sequences_dims.size(); + for (int64_t dim = 0; dim < sequences_dims_size - 1; ++dim) { + if (sequences_dims[dim] != values_dims[dim]) { + return false; + } + } + return true; + } + + void InferShape(framework::InferShapeContext* ctx) const override { + OP_INOUT_CHECK(ctx->HasInput("SortedSequence"), "Input", "SortedSequence", + "searchsorted"); + OP_INOUT_CHECK(ctx->HasInput("Values"), "Input", "Values", "searchsorted"); + + OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "searchsorted"); + + auto sequences_dims = ctx->GetInputDim("SortedSequence"); + auto values_dims = ctx->GetInputDim("Values"); + auto out_int32 = ctx->Attrs().Get("out_int32"); + + if (sequences_dims.size() != 1) { + PADDLE_ENFORCE_EQ( + SearchsortedDimsMatchedBeforeLastDim(sequences_dims, values_dims), + true, + platform::errors::Unavailable( + "The dimensions of sorted_sequence tensor ( %s ) and values " + "tensor ( %s ) can not match. Because the input sorted_sequence " + "tensor must be 1 dimension or the first N-1 dimensions of " + "sorted_sequence tensor and input values tensor must match. " + "Please input appropriate sorted_sequence and values again! ", + sequences_dims, values_dims)); + } + + if (out_int32) { + PADDLE_ENFORCE_LT( + sequences_dims[sequences_dims.size() - 1], + std::numeric_limits::max(), + platform::errors::Unavailable( + "The size of sorted_sequence %d exceed the maximum limit d%. " + "Because the size of sorted_sequence should be less than the " + "output maximum value for int32 bit. Please set appropriate " + "sorted_sequence to meet this requirement! ", + sequences_dims[sequences_dims.size() - 1], + std::numeric_limits::max())); + } + + ctx->SetOutputDim("Out", values_dims); + } + + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override { + auto data_type = + OperatorWithKernel::IndicateVarDataType(ctx, "SortedSequence"); + return framework::OpKernelType(data_type, ctx.device_context()); + } +}; + +class SearchSortedOpMaker : public framework::OpProtoAndCheckerMaker { + public: + void Make() override { + AddInput("SortedSequence", + "(Tensor), N-D or 1-D tensor, The value of the tensor" + "monotonically increases in the innermost dimension."); + AddInput("Values", "(Tensor), N-D tensor given values."); + AddOutput("Out", "(Tensor), The output tensor of searchsorted op."); + AddAttr("out_int32", + "the output tensor is int64 type if False and on the" + "contrary for int32") + .SetDefault(false); + AddAttr( + "right", + "corresponding to lower bound if False and upper bound if True") + .SetDefault(false); + + AddComment(R"DOC( + Searchsorted Operator. + + This OP is used to find the index of the corresponding sorted_sequence in the innermost dimension based on the given values. 
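+
+  For example, with sorted_sequence = [1, 3, 5, 7, 9, 11] and values = [3, 6, 9, 10],
+  the output is [1, 3, 4, 5] when the attribute right is False (lower bound) and
+  [2, 3, 5, 5] when right is True (upper bound). If an element of values is nan or
+  inf, the size of the innermost dimension of sorted_sequence is returned for it.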
+ +)DOC"); + } +}; +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; + +REGISTER_OPERATOR(searchsorted, ops::SearchSortedOp, ops::SearchSortedOpMaker); + +REGISTER_OP_CPU_KERNEL( + searchsorted, + ops::SearchSortedKernel, + ops::SearchSortedKernel, + ops::SearchSortedKernel, + ops::SearchSortedKernel); diff --git a/paddle/fluid/operators/searchsorted_op.cu b/paddle/fluid/operators/searchsorted_op.cu new file mode 100644 index 0000000000..4633ab43ef --- /dev/null +++ b/paddle/fluid/operators/searchsorted_op.cu @@ -0,0 +1,23 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/operators/searchsorted_op.h" +namespace ops = paddle::operators; +namespace plat = paddle::platform; + +REGISTER_OP_CUDA_KERNEL( + searchsorted, ops::SearchSortedKernel, + ops::SearchSortedKernel, + ops::SearchSortedKernel, + ops::SearchSortedKernel); diff --git a/paddle/fluid/operators/searchsorted_op.h b/paddle/fluid/operators/searchsorted_op.h new file mode 100644 index 0000000000..5ae0e79907 --- /dev/null +++ b/paddle/fluid/operators/searchsorted_op.h @@ -0,0 +1,170 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
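+
+// Implementation sketch (informal summary of the functor defined below, not a
+// normative spec): for each flattened index idx of Values, the kernel selects
+// the matching row of SortedSequence (or the whole tensor when it is 1-D) and
+// computes, roughly,
+//
+//   seq = is_1d ? sequence_data : sequence_data + idx / val_size * seq_size;
+//   out[idx] = IsNan(v) || IsInf(v)
+//                  ? seq_size
+//                  : (right ? UpperBound(seq, seq_size, v)
+//                           : LowerBound(seq, seq_size, v));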
+ +#pragma once + +#include + +#include "paddle/fluid/framework/ddim.h" +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/framework/operator.h" +#include "paddle/fluid/operators/math/algorithm.h" +#include "paddle/fluid/platform/enforce.h" +#include "paddle/fluid/platform/for_range.h" + +namespace paddle { +namespace operators { +using Tensor = framework::Tensor; + +template +class GpuAndCpuSearchSortedCompute { + public: + static HOSTDEVICE bool IsNan(float x) { return ::isnan(x); } + static HOSTDEVICE bool IsNan(double x) { return ::isnan(x); } + static HOSTDEVICE bool IsNan(int x) { return false; } + static HOSTDEVICE bool IsNan(int64_t x) { return false; } + + static HOSTDEVICE bool IsInf(float x) { return ::isinf(x); } + static HOSTDEVICE bool IsInf(double x) { return ::isinf(x); } + static HOSTDEVICE bool IsInf(int x) { return false; } + static HOSTDEVICE bool IsInf(int64_t x) { return false; } + + HOSTDEVICE GpuAndCpuSearchSortedCompute(const T1* sequence_data, + const T2* value_data, bool right, + bool is_1d_boundaries, + int64_t val_size, int64_t seq_size, + OutType* out_data) + : sequence_data_(sequence_data), + value_data_(value_data), + right_(right), + is_1d_boundaries_(is_1d_boundaries), + val_size_(val_size), + seq_size_(seq_size), + out_data_(out_data) {} + HOSTDEVICE void operator()(int64_t idx) { + const T2* value_ptr = value_data_ + idx; + const T1* sequence_ptr = is_1d_boundaries_ + ? sequence_data_ + : sequence_data_ + idx / val_size_ * seq_size_; + if (IsInf(*value_ptr) || IsNan(*value_ptr)) { + out_data_[idx] = seq_size_; + } else { + if (right_) { + out_data_[idx] = static_cast( + math::UpperBound(sequence_ptr, seq_size_, *value_ptr)); + } else { + out_data_[idx] = static_cast( + math::LowerBound(sequence_ptr, seq_size_, *value_ptr)); + } + } + } + + private: + const T1* sequence_data_; + const T2* value_data_; + bool right_; + bool is_1d_boundaries_; + int64_t val_size_; + int64_t seq_size_; + OutType* out_data_; +}; + +template +class SearchSortedFunctor { + public: + SearchSortedFunctor(const framework::ExecutionContext& context, + const framework::Tensor* sorted_sequence, + const framework::Tensor* value, bool right, + OutType* out_data) + : context_(context), + sorted_sequence_(sorted_sequence), + value_(value), + right_(right), + out_data_(out_data) {} + + template + void apply() { + const T1* sequence_data = sorted_sequence_->data(); + const T2* value_data = value_->data(); + const framework::DDim& seq_dims = sorted_sequence_->dims(); + const framework::DDim& val_dims = value_->dims(); + + bool is_1d_boundaries = seq_dims.size() == 1; + int64_t val_size = val_dims[val_dims.size() - 1]; + int64_t seq_size = seq_dims[seq_dims.size() - 1]; + + auto& dev_ctx = context_.template device_context(); + platform::ForRange for_range(dev_ctx, value_->numel()); + GpuAndCpuSearchSortedCompute + gpu_and_cpu_search_sorted_compute(sequence_data, value_data, right_, + is_1d_boundaries, val_size, seq_size, + out_data_); + for_range(gpu_and_cpu_search_sorted_compute); + } + + private: + const framework::ExecutionContext& context_; + const framework::Tensor* sorted_sequence_; + const framework::Tensor* value_; + bool right_; + OutType* out_data_; +}; + +template +static void VisitDataType(framework::proto::VarType::Type type, + Visitor visitor) { + if (type == framework::proto::VarType::FP32) { + visitor.template apply(); + } else if (type == framework::proto::VarType::FP64) { + visitor.template apply(); + } else if (type == 
framework::proto::VarType::INT32) { + visitor.template apply(); + } else if (type == framework::proto::VarType::INT64) { + visitor.template apply(); + } else { + PADDLE_THROW(platform::errors::InvalidArgument( + "The recieved values data type %s can not meet input requirements. " + "Because the given values data type of searchsorted operators must be " + "float32, float64, int32 or int64. Please input appropriate " + "sorted_sequence again! ", + framework::DataTypeToString(type))); + } +} + +template +class SearchSortedKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + auto* sorted_sequence = context.Input("SortedSequence"); + auto* value = context.Input("Values"); + bool out_int32 = context.Attr("out_int32"); + bool right = context.Attr("right"); + auto* out = context.Output("Out"); + + if (out_int32) { + int* out_data = out->mutable_data(context.GetPlace()); + SearchSortedFunctor functor( + context, sorted_sequence, value, right, out_data); + VisitDataType(value->type(), functor); + } else { + int64_t* out_data = out->mutable_data(context.GetPlace()); + SearchSortedFunctor functor( + context, sorted_sequence, value, right, out_data); + VisitDataType(value->type(), functor); + } + } +}; + +} // namespace operators +} // namespace paddle diff --git a/python/paddle/__init__.py b/python/paddle/__init__.py index 9d60a5b381..d0b705cde6 100755 --- a/python/paddle/__init__.py +++ b/python/paddle/__init__.py @@ -233,6 +233,7 @@ from .tensor.random import randperm # noqa: F401 from .tensor.search import argmax # noqa: F401 from .tensor.search import argmin # noqa: F401 from .tensor.search import argsort # noqa: F401 +from .tensor.search import searchsorted # noqa: F401 from .tensor.search import masked_select # noqa: F401 from .tensor.search import topk # noqa: F401 from .tensor.search import where # noqa: F401 @@ -358,6 +359,7 @@ __all__ = [ # noqa 'summary', 'flops', 'sort', + 'searchsorted', 'split', 'logical_and', 'full_like', diff --git a/python/paddle/fluid/tests/unittests/test_searchsorted_op.py b/python/paddle/fluid/tests/unittests/test_searchsorted_op.py new file mode 100644 index 0000000000..f595d06d5b --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_searchsorted_op.py @@ -0,0 +1,198 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import unittest +import numpy as np +from op_test import OpTest +import paddle +import paddle.fluid as fluid +import paddle.fluid.core as core + +paddle.enable_static() +from op_test import OpTest + + +class TestSearchSorted(OpTest): + def setUp(self): + + self.op_type = "searchsorted" + self.init_test_case() + + self.inputs = { + 'SortedSequence': self.sorted_sequence, + 'Values': self.values + } + self.attrs = {"out_int32": False, "right": False} + self.attrs["right"] = True if self.side == 'right' else False + self.outputs = { + 'Out': np.searchsorted( + self.sorted_sequence, self.values, side=self.side) + } + + def test_check_output(self): + self.check_output() + + def init_test_case(self): + self.sorted_sequence = np.array([1, 3, 5, 7, 9]).astype("float32") + self.values = np.array([[3, 6, 9], [3, 6, 9]]).astype("float32") + self.side = "left" + + +class TestSearchSortedOp1(TestSearchSorted): + def init_test_case(self): + self.sorted_sequence = np.array([1, 3, 5, 7, 9]).astype("int32") + self.values = np.array([[3, 6, 9], [3, 6, 9]]).astype("int32") + self.side = "right" + + +class TestSearchSortedOp2(TestSearchSorted): + def init_test_case(self): + self.sorted_sequence = np.array([1, 3, 5, 7, 9]).astype("int64") + self.values = np.array([[3, 6, 9], [3, 6, 9]]).astype("int64") + self.side = "left" + + +class TestSearchSortedOp3(TestSearchSorted): + def init_test_case(self): + self.sorted_sequence = np.array([1, 3, 5, 7, 9]).astype("float64") + self.values = np.array( + [[np.nan, np.nan, np.nan], [3, 6, 9]]).astype("float64") + self.side = "left" + + +class TestSearchSortedOp4(TestSearchSorted): + def init_test_case(self): + self.sorted_sequence = np.array([1, 3, 5, 7, 9]).astype("float64") + self.values = np.array( + [[np.inf, np.inf, np.inf], [3, 6, 9]]).astype("float64") + self.side = "right" + + +class TestSearchSortedOp5(TestSearchSorted): + def init_test_case(self): + self.sorted_sequence = np.array([1, 3, 5, 7, 9]).astype("float64") + self.values = np.array([[np.inf, np.inf, np.inf], + [np.nan, np.nan, np.nan]]).astype("float64") + self.side = "right" + + +class TestSearchSortedAPI(unittest.TestCase): + def init_test_case(self): + self.sorted_sequence = np.array([2, 4, 6, 8, 10]).astype("float64") + self.values = np.array([[3, 6, 9], [3, 6, 9]]).astype("float64") + + def setUp(self): + self.init_test_case() + self.place = [paddle.CPUPlace()] + if core.is_compiled_with_cuda(): + self.place.append(paddle.CUDAPlace(0)) + + def test_static_api(self): + paddle.enable_static() + + def run(place): + with paddle.static.program_guard(paddle.static.Program()): + sorted_sequence = paddle.static.data( + 'SortedSequence', + shape=self.sorted_sequence.shape, + dtype="float64") + values = paddle.static.data( + 'Values', shape=self.values.shape, dtype="float64") + out = paddle.searchsorted(sorted_sequence, values) + exe = paddle.static.Executor(place) + res = exe.run(feed={ + 'SortedSequence': self.sorted_sequence, + 'Values': self.values + }, + fetch_list=out) + out_ref = np.searchsorted(self.sorted_sequence, self.values) + self.assertTrue(np.allclose(out_ref, res)) + + for place in self.place: + run(place) + + def test_dygraph_api(self): + def run(place): + + paddle.disable_static(place) + sorted_sequence = paddle.to_tensor(self.sorted_sequence) + values = paddle.to_tensor(self.values) + out = paddle.searchsorted(sorted_sequence, values, right=True) + out_ref = np.searchsorted( + self.sorted_sequence, self.values, side='right') + self.assertEqual(np.allclose(out_ref, out.numpy()), True) + 
paddle.enable_static() + + for place in self.place: + run(place) + + def test_out_int32(self): + paddle.disable_static() + sorted_sequence = paddle.to_tensor(self.sorted_sequence) + values = paddle.to_tensor(self.values) + out = paddle.searchsorted(sorted_sequence, values, out_int32=True) + self.assertTrue(out.type, 'int32') + + +class TestSearchSortedError(unittest.TestCase): + def test_error_api(self): + paddle.enable_static() + + def test_searchsorted_dims_matched_before_lastdim_error1(): + with paddle.static.program_guard(paddle.static.Program()): + sorted_sequence = paddle.static.data( + 'SortedSequence', shape=[2, 2, 3], dtype="float64") + values = paddle.static.data( + 'Values', shape=[2, 5], dtype="float64") + out = paddle.searchsorted(sorted_sequence, values) + + self.assertRaises(RuntimeError, + test_searchsorted_dims_matched_before_lastdim_error1) + + def test_searchsorted_dims_matched_before_lastdim_error2(): + with paddle.static.program_guard(paddle.static.Program()): + sorted_sequence = paddle.static.data( + 'SortedSequence', shape=[2, 2, 3], dtype="float64") + values = paddle.static.data( + 'Values', shape=[2, 3, 5], dtype="float64") + out = paddle.searchsorted(sorted_sequence, values) + + self.assertRaises(RuntimeError, + test_searchsorted_dims_matched_before_lastdim_error2) + + def test_searchsorted_sortedsequence_size_error(): + with paddle.static.program_guard(paddle.static.Program()): + sorted_sequence = paddle.static.data( + 'SortedSequence', shape=[2, 2, pow(2, 34)], dtype="float64") + values = paddle.static.data( + 'Values', shape=[2, 2, 5], dtype="float64") + out = paddle.searchsorted( + sorted_sequence, values, out_int32=True) + + self.assertRaises(RuntimeError, + test_searchsorted_sortedsequence_size_error) + + def test_sortedsequence_values_type_error(): + with paddle.static.program_guard(paddle.static.Program()): + sorted_sequence = paddle.static.data( + 'SortedSequence', shape=[2, 3], dtype="int16") + values = paddle.static.data( + 'Values', shape=[2, 5], dtype="int16") + out = paddle.searchsorted(sorted_sequence, values) + + self.assertRaises(TypeError, test_sortedsequence_values_type_error) + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/tensor/__init__.py b/python/paddle/tensor/__init__.py index b9e0c75a60..73369a6e8e 100755 --- a/python/paddle/tensor/__init__.py +++ b/python/paddle/tensor/__init__.py @@ -190,6 +190,7 @@ from .random import randperm # noqa: F401 from .search import argmax # noqa: F401 from .search import argmin # noqa: F401 from .search import argsort # noqa: F401 +from .search import searchsorted # noqa: F401 from .search import topk # noqa: F401 from .search import where # noqa: F401 from .search import index_select # noqa: F401 diff --git a/python/paddle/tensor/search.py b/python/paddle/tensor/search.py index 79eeae78a4..f3587aa48d 100644 --- a/python/paddle/tensor/search.py +++ b/python/paddle/tensor/search.py @@ -766,3 +766,75 @@ def topk(x, k, axis=None, largest=True, sorted=True, name=None): attrs=attrs) indices.stop_gradient = True return values, indices + + +def searchsorted(sorted_sequence, + values, + out_int32=False, + right=False, + name=None): + """ + This OP is used to find the index of the corresponding `sorted_sequence` in the innermost dimension based on the given `values`. + + Args: + sorted_sequence(Tensor): An input N-D or 1-D tensor with type int32, int64, float32, float64. The value of the tensor monotonically increases in the innermost dimension. 
+        values(Tensor): An input N-D tensor of values, with type int32, int64, float32 or float64.
+        out_int32(bool, optional): Data type of the output tensor, which can be int32 or int64. The default value is False, which means the output data type is int64.
+        right(bool, optional): Whether to return the upper bound (True) or the lower bound (False) of `values` in the innermost dimension of `sorted_sequence`. If a value in `values` is nan or inf, the size of the innermost dimension is returned for it.
+            The default value is False, which means the lower bound is returned.
+        name(str, optional): The default value is None. Normally there is no need for user to set this property. For more information, please refer to :ref:`api_guide_Name`.
+
+    Returns:
+        Tensor (with the same shape as `values`), a tensor of int32 if :attr:`out_int32` is True, otherwise a tensor of int64.
+
+    Examples:
+
+        .. code-block:: python
+
+            import paddle
+
+            sorted_sequence = paddle.to_tensor([[1, 3, 5, 7, 9, 11],
+                                                [2, 4, 6, 8, 10, 12]], dtype='int32')
+            values = paddle.to_tensor([[3, 6, 9, 10], [3, 6, 9, 10]], dtype='int32')
+            out1 = paddle.searchsorted(sorted_sequence, values)
+            print(out1)
+            # Tensor(shape=[2, 4], dtype=int64, place=CUDAPlace(0), stop_gradient=True,
+            #        [[1, 3, 4, 5],
+            #         [1, 2, 4, 4]])
+            out2 = paddle.searchsorted(sorted_sequence, values, right=True)
+            print(out2)
+            # Tensor(shape=[2, 4], dtype=int64, place=CUDAPlace(0), stop_gradient=True,
+            #        [[2, 3, 5, 5],
+            #         [1, 3, 4, 5]])
+            sorted_sequence_1d = paddle.to_tensor([1, 3, 5, 7, 9, 11, 13])
+            out3 = paddle.searchsorted(sorted_sequence_1d, values)
+            print(out3)
+            # Tensor(shape=[2, 4], dtype=int64, place=CUDAPlace(0), stop_gradient=True,
+            #        [[1, 3, 4, 5],
+            #         [1, 3, 4, 5]])
+
+    """
+
+    if in_dygraph_mode():
+        return _C_ops.searchsorted(sorted_sequence, values, "out_int32",
+                                   out_int32, "right", right)
+
+    check_variable_and_dtype(sorted_sequence, 'SortedSequence',
+                             ['float32', 'float64', 'int32', 'int64'],
+                             'paddle.searchsorted')
+    check_variable_and_dtype(values, 'Values',
+                             ['float32', 'float64', 'int32', 'int64'],
+                             'paddle.searchsorted')
+
+    helper = LayerHelper('searchsorted', **locals())
+    out_type = 'int32' if out_int32 else 'int64'
+    out = helper.create_variable_for_type_inference(dtype=out_type)
+    helper.append_op(
+        type='searchsorted',
+        inputs={'SortedSequence': sorted_sequence,
+                "Values": values},
+        outputs={'Out': out},
+        attrs={"out_int32": out_int32,
+               "right": right})
+
+    return out
-- GitLab
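
For reference, the unit tests in test_searchsorted_op.py validate against NumPy's np.searchsorted. A minimal standalone sketch of that reference behavior (plain NumPy only; the variable names are illustrative and not part of the patch):

    import numpy as np

    sorted_sequence = np.array([1, 3, 5, 7, 9, 11])
    values = np.array([3, 6, 9, 10])

    # side='left' corresponds to right=False (lower bound):
    # first index i with value <= sorted_sequence[i]
    print(np.searchsorted(sorted_sequence, values, side='left'))   # [1 3 4 5]

    # side='right' corresponds to right=True (upper bound):
    # first index i with value < sorted_sequence[i]
    print(np.searchsorted(sorted_sequence, values, side='right'))  # [2 3 5 5]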