未验证 提交 66223048 编写于 作者: Y Yanxing Shi 提交者: GitHub

Add searchsorted op (#35159)

* fix github name

* fix CI error

* fix review and CI error

* fix inf,nan error and modify unittest samples

* add unittest samples

* add unittest samples

* fix unittest error

* test=document_fix

* test=document_fix

* modify doc and add unittest samples

* fix error newline in constant

* modify doc after mentor review

* modify __all__ and doc

* modify doc
上级 2bb44317
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. // Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
// //
// Licensed under the Apache License, Version 2.0 (the "License"); // Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License. // you may not use this file except in compliance with the License.
...@@ -39,8 +39,8 @@ HOSTDEVICE inline int64_t BinarySearch(const T *x, int64_t num, const T &val) { ...@@ -39,8 +39,8 @@ HOSTDEVICE inline int64_t BinarySearch(const T *x, int64_t num, const T &val) {
return -1; return -1;
} }
template <typename T> template <typename T1, typename T2>
HOSTDEVICE inline size_t LowerBound(const T *x, size_t num, const T &val) { HOSTDEVICE inline size_t LowerBound(const T1 *x, size_t num, const T2 &val) {
#if defined(__CUDA_ARCH__) || defined(__HIPCC__) // @{ Group LowerBound #if defined(__CUDA_ARCH__) || defined(__HIPCC__) // @{ Group LowerBound
// The following code is from // The following code is from
// https://en.cppreference.com/w/cpp/algorithm/lower_bound // https://en.cppreference.com/w/cpp/algorithm/lower_bound
...@@ -62,8 +62,8 @@ HOSTDEVICE inline size_t LowerBound(const T *x, size_t num, const T &val) { ...@@ -62,8 +62,8 @@ HOSTDEVICE inline size_t LowerBound(const T *x, size_t num, const T &val) {
#endif // @} End Group LowerBound #endif // @} End Group LowerBound
} }
template <typename T> template <typename T1, typename T2>
HOSTDEVICE inline size_t UpperBound(const T *x, size_t num, const T &val) { HOSTDEVICE inline size_t UpperBound(const T1 *x, size_t num, const T2 &val) {
#if defined(__CUDA_ARCH__) || defined(__HIPCC__) // @{ Group UpperBound #if defined(__CUDA_ARCH__) || defined(__HIPCC__) // @{ Group UpperBound
// The following code is from // The following code is from
// https://en.cppreference.com/w/cpp/algorithm/upper_bound // https://en.cppreference.com/w/cpp/algorithm/upper_bound
......
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/searchsorted_op.h"
#include "paddle/fluid/platform/enforce.h"
namespace paddle {
namespace operators {
class SearchSortedOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
static bool SearchsortedDimsMatchedBeforeLastDim(
const framework::DDim& sequences_dims,
const framework::DDim& values_dims) {
if (sequences_dims.size() != values_dims.size()) {
return false;
}
const auto& sequences_dims_size = sequences_dims.size();
for (int64_t dim = 0; dim < sequences_dims_size - 1; ++dim) {
if (sequences_dims[dim] != values_dims[dim]) {
return false;
}
}
return true;
}
void InferShape(framework::InferShapeContext* ctx) const override {
OP_INOUT_CHECK(ctx->HasInput("SortedSequence"), "Input", "SortedSequence",
"searchsorted");
OP_INOUT_CHECK(ctx->HasInput("Values"), "Input", "Values", "searchsorted");
OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "searchsorted");
auto sequences_dims = ctx->GetInputDim("SortedSequence");
auto values_dims = ctx->GetInputDim("Values");
auto out_int32 = ctx->Attrs().Get<bool>("out_int32");
if (sequences_dims.size() != 1) {
PADDLE_ENFORCE_EQ(
SearchsortedDimsMatchedBeforeLastDim(sequences_dims, values_dims),
true,
platform::errors::Unavailable(
"The dimensions of sorted_sequence tensor ( %s ) and values "
"tensor ( %s ) can not match. Because the input sorted_sequence "
"tensor must be 1 dimension or the first N-1 dimensions of "
"sorted_sequence tensor and input values tensor must match. "
"Please input appropriate sorted_sequence and values again! ",
sequences_dims, values_dims));
}
if (out_int32) {
PADDLE_ENFORCE_LT(
sequences_dims[sequences_dims.size() - 1],
std::numeric_limits<int>::max(),
platform::errors::Unavailable(
"The size of sorted_sequence %d exceed the maximum limit d%. "
"Because the size of sorted_sequence should be less than the "
"output maximum value for int32 bit. Please set appropriate "
"sorted_sequence to meet this requirement! ",
sequences_dims[sequences_dims.size() - 1],
std::numeric_limits<int>::max()));
}
ctx->SetOutputDim("Out", values_dims);
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
auto data_type =
OperatorWithKernel::IndicateVarDataType(ctx, "SortedSequence");
return framework::OpKernelType(data_type, ctx.device_context());
}
};
class SearchSortedOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("SortedSequence",
"(Tensor), N-D or 1-D tensor, The value of the tensor"
"monotonically increases in the innermost dimension.");
AddInput("Values", "(Tensor), N-D tensor given values.");
AddOutput("Out", "(Tensor), The output tensor of searchsorted op.");
AddAttr<bool>("out_int32",
"the output tensor is int64 type if False and on the"
"contrary for int32")
.SetDefault(false);
AddAttr<bool>(
"right",
"corresponding to lower bound if False and upper bound if True")
.SetDefault(false);
AddComment(R"DOC(
Searchsorted Operator.
This OP is used to find the index of the corresponding sorted_sequence in the innermost dimension based on the given values.
)DOC");
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(searchsorted, ops::SearchSortedOp, ops::SearchSortedOpMaker);
REGISTER_OP_CPU_KERNEL(
searchsorted,
ops::SearchSortedKernel<paddle::platform::CPUDeviceContext, float>,
ops::SearchSortedKernel<paddle::platform::CPUDeviceContext, double>,
ops::SearchSortedKernel<paddle::platform::CPUDeviceContext, int>,
ops::SearchSortedKernel<paddle::platform::CPUDeviceContext, int64_t>);
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/searchsorted_op.h"
namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OP_CUDA_KERNEL(
searchsorted, ops::SearchSortedKernel<plat::CUDADeviceContext, float>,
ops::SearchSortedKernel<plat::CUDADeviceContext, double>,
ops::SearchSortedKernel<plat::CUDADeviceContext, int>,
ops::SearchSortedKernel<plat::CUDADeviceContext, int64_t>);
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <math.h>
#include "paddle/fluid/framework/ddim.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/operators/math/algorithm.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/for_range.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
template <typename T1, typename T2, typename OutType>
class GpuAndCpuSearchSortedCompute {
public:
static HOSTDEVICE bool IsNan(float x) { return ::isnan(x); }
static HOSTDEVICE bool IsNan(double x) { return ::isnan(x); }
static HOSTDEVICE bool IsNan(int x) { return false; }
static HOSTDEVICE bool IsNan(int64_t x) { return false; }
static HOSTDEVICE bool IsInf(float x) { return ::isinf(x); }
static HOSTDEVICE bool IsInf(double x) { return ::isinf(x); }
static HOSTDEVICE bool IsInf(int x) { return false; }
static HOSTDEVICE bool IsInf(int64_t x) { return false; }
HOSTDEVICE GpuAndCpuSearchSortedCompute(const T1* sequence_data,
const T2* value_data, bool right,
bool is_1d_boundaries,
int64_t val_size, int64_t seq_size,
OutType* out_data)
: sequence_data_(sequence_data),
value_data_(value_data),
right_(right),
is_1d_boundaries_(is_1d_boundaries),
val_size_(val_size),
seq_size_(seq_size),
out_data_(out_data) {}
HOSTDEVICE void operator()(int64_t idx) {
const T2* value_ptr = value_data_ + idx;
const T1* sequence_ptr = is_1d_boundaries_
? sequence_data_
: sequence_data_ + idx / val_size_ * seq_size_;
if (IsInf(*value_ptr) || IsNan(*value_ptr)) {
out_data_[idx] = seq_size_;
} else {
if (right_) {
out_data_[idx] = static_cast<OutType>(
math::UpperBound<T1, T2>(sequence_ptr, seq_size_, *value_ptr));
} else {
out_data_[idx] = static_cast<OutType>(
math::LowerBound<T1, T2>(sequence_ptr, seq_size_, *value_ptr));
}
}
}
private:
const T1* sequence_data_;
const T2* value_data_;
bool right_;
bool is_1d_boundaries_;
int64_t val_size_;
int64_t seq_size_;
OutType* out_data_;
};
template <typename DeviceContext, typename T1, typename OutType>
class SearchSortedFunctor {
public:
SearchSortedFunctor(const framework::ExecutionContext& context,
const framework::Tensor* sorted_sequence,
const framework::Tensor* value, bool right,
OutType* out_data)
: context_(context),
sorted_sequence_(sorted_sequence),
value_(value),
right_(right),
out_data_(out_data) {}
template <typename T2>
void apply() {
const T1* sequence_data = sorted_sequence_->data<T1>();
const T2* value_data = value_->data<T2>();
const framework::DDim& seq_dims = sorted_sequence_->dims();
const framework::DDim& val_dims = value_->dims();
bool is_1d_boundaries = seq_dims.size() == 1;
int64_t val_size = val_dims[val_dims.size() - 1];
int64_t seq_size = seq_dims[seq_dims.size() - 1];
auto& dev_ctx = context_.template device_context<DeviceContext>();
platform::ForRange<DeviceContext> for_range(dev_ctx, value_->numel());
GpuAndCpuSearchSortedCompute<T1, T2, OutType>
gpu_and_cpu_search_sorted_compute(sequence_data, value_data, right_,
is_1d_boundaries, val_size, seq_size,
out_data_);
for_range(gpu_and_cpu_search_sorted_compute);
}
private:
const framework::ExecutionContext& context_;
const framework::Tensor* sorted_sequence_;
const framework::Tensor* value_;
bool right_;
OutType* out_data_;
};
template <typename Visitor>
static void VisitDataType(framework::proto::VarType::Type type,
Visitor visitor) {
if (type == framework::proto::VarType::FP32) {
visitor.template apply<float>();
} else if (type == framework::proto::VarType::FP64) {
visitor.template apply<double>();
} else if (type == framework::proto::VarType::INT32) {
visitor.template apply<int>();
} else if (type == framework::proto::VarType::INT64) {
visitor.template apply<int64_t>();
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"The recieved values data type %s can not meet input requirements. "
"Because the given values data type of searchsorted operators must be "
"float32, float64, int32 or int64. Please input appropriate "
"sorted_sequence again! ",
framework::DataTypeToString(type)));
}
}
template <typename DeviceContext, typename T>
class SearchSortedKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* sorted_sequence = context.Input<Tensor>("SortedSequence");
auto* value = context.Input<Tensor>("Values");
bool out_int32 = context.Attr<bool>("out_int32");
bool right = context.Attr<bool>("right");
auto* out = context.Output<Tensor>("Out");
if (out_int32) {
int* out_data = out->mutable_data<int>(context.GetPlace());
SearchSortedFunctor<DeviceContext, T, int> functor(
context, sorted_sequence, value, right, out_data);
VisitDataType(value->type(), functor);
} else {
int64_t* out_data = out->mutable_data<int64_t>(context.GetPlace());
SearchSortedFunctor<DeviceContext, T, int64_t> functor(
context, sorted_sequence, value, right, out_data);
VisitDataType(value->type(), functor);
}
}
};
} // namespace operators
} // namespace paddle
...@@ -233,6 +233,7 @@ from .tensor.random import randperm # noqa: F401 ...@@ -233,6 +233,7 @@ from .tensor.random import randperm # noqa: F401
from .tensor.search import argmax # noqa: F401 from .tensor.search import argmax # noqa: F401
from .tensor.search import argmin # noqa: F401 from .tensor.search import argmin # noqa: F401
from .tensor.search import argsort # noqa: F401 from .tensor.search import argsort # noqa: F401
from .tensor.search import searchsorted # noqa: F401
from .tensor.search import masked_select # noqa: F401 from .tensor.search import masked_select # noqa: F401
from .tensor.search import topk # noqa: F401 from .tensor.search import topk # noqa: F401
from .tensor.search import where # noqa: F401 from .tensor.search import where # noqa: F401
...@@ -358,6 +359,7 @@ __all__ = [ # noqa ...@@ -358,6 +359,7 @@ __all__ = [ # noqa
'summary', 'summary',
'flops', 'flops',
'sort', 'sort',
'searchsorted',
'split', 'split',
'logical_and', 'logical_and',
'full_like', 'full_like',
......
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
from op_test import OpTest
import paddle
import paddle.fluid as fluid
import paddle.fluid.core as core
paddle.enable_static()
from op_test import OpTest
class TestSearchSorted(OpTest):
def setUp(self):
self.op_type = "searchsorted"
self.init_test_case()
self.inputs = {
'SortedSequence': self.sorted_sequence,
'Values': self.values
}
self.attrs = {"out_int32": False, "right": False}
self.attrs["right"] = True if self.side == 'right' else False
self.outputs = {
'Out': np.searchsorted(
self.sorted_sequence, self.values, side=self.side)
}
def test_check_output(self):
self.check_output()
def init_test_case(self):
self.sorted_sequence = np.array([1, 3, 5, 7, 9]).astype("float32")
self.values = np.array([[3, 6, 9], [3, 6, 9]]).astype("float32")
self.side = "left"
class TestSearchSortedOp1(TestSearchSorted):
def init_test_case(self):
self.sorted_sequence = np.array([1, 3, 5, 7, 9]).astype("int32")
self.values = np.array([[3, 6, 9], [3, 6, 9]]).astype("int32")
self.side = "right"
class TestSearchSortedOp2(TestSearchSorted):
def init_test_case(self):
self.sorted_sequence = np.array([1, 3, 5, 7, 9]).astype("int64")
self.values = np.array([[3, 6, 9], [3, 6, 9]]).astype("int64")
self.side = "left"
class TestSearchSortedOp3(TestSearchSorted):
def init_test_case(self):
self.sorted_sequence = np.array([1, 3, 5, 7, 9]).astype("float64")
self.values = np.array(
[[np.nan, np.nan, np.nan], [3, 6, 9]]).astype("float64")
self.side = "left"
class TestSearchSortedOp4(TestSearchSorted):
def init_test_case(self):
self.sorted_sequence = np.array([1, 3, 5, 7, 9]).astype("float64")
self.values = np.array(
[[np.inf, np.inf, np.inf], [3, 6, 9]]).astype("float64")
self.side = "right"
class TestSearchSortedOp5(TestSearchSorted):
def init_test_case(self):
self.sorted_sequence = np.array([1, 3, 5, 7, 9]).astype("float64")
self.values = np.array([[np.inf, np.inf, np.inf],
[np.nan, np.nan, np.nan]]).astype("float64")
self.side = "right"
class TestSearchSortedAPI(unittest.TestCase):
def init_test_case(self):
self.sorted_sequence = np.array([2, 4, 6, 8, 10]).astype("float64")
self.values = np.array([[3, 6, 9], [3, 6, 9]]).astype("float64")
def setUp(self):
self.init_test_case()
self.place = [paddle.CPUPlace()]
if core.is_compiled_with_cuda():
self.place.append(paddle.CUDAPlace(0))
def test_static_api(self):
paddle.enable_static()
def run(place):
with paddle.static.program_guard(paddle.static.Program()):
sorted_sequence = paddle.static.data(
'SortedSequence',
shape=self.sorted_sequence.shape,
dtype="float64")
values = paddle.static.data(
'Values', shape=self.values.shape, dtype="float64")
out = paddle.searchsorted(sorted_sequence, values)
exe = paddle.static.Executor(place)
res = exe.run(feed={
'SortedSequence': self.sorted_sequence,
'Values': self.values
},
fetch_list=out)
out_ref = np.searchsorted(self.sorted_sequence, self.values)
self.assertTrue(np.allclose(out_ref, res))
for place in self.place:
run(place)
def test_dygraph_api(self):
def run(place):
paddle.disable_static(place)
sorted_sequence = paddle.to_tensor(self.sorted_sequence)
values = paddle.to_tensor(self.values)
out = paddle.searchsorted(sorted_sequence, values, right=True)
out_ref = np.searchsorted(
self.sorted_sequence, self.values, side='right')
self.assertEqual(np.allclose(out_ref, out.numpy()), True)
paddle.enable_static()
for place in self.place:
run(place)
def test_out_int32(self):
paddle.disable_static()
sorted_sequence = paddle.to_tensor(self.sorted_sequence)
values = paddle.to_tensor(self.values)
out = paddle.searchsorted(sorted_sequence, values, out_int32=True)
self.assertTrue(out.type, 'int32')
class TestSearchSortedError(unittest.TestCase):
def test_error_api(self):
paddle.enable_static()
def test_searchsorted_dims_matched_before_lastdim_error1():
with paddle.static.program_guard(paddle.static.Program()):
sorted_sequence = paddle.static.data(
'SortedSequence', shape=[2, 2, 3], dtype="float64")
values = paddle.static.data(
'Values', shape=[2, 5], dtype="float64")
out = paddle.searchsorted(sorted_sequence, values)
self.assertRaises(RuntimeError,
test_searchsorted_dims_matched_before_lastdim_error1)
def test_searchsorted_dims_matched_before_lastdim_error2():
with paddle.static.program_guard(paddle.static.Program()):
sorted_sequence = paddle.static.data(
'SortedSequence', shape=[2, 2, 3], dtype="float64")
values = paddle.static.data(
'Values', shape=[2, 3, 5], dtype="float64")
out = paddle.searchsorted(sorted_sequence, values)
self.assertRaises(RuntimeError,
test_searchsorted_dims_matched_before_lastdim_error2)
def test_searchsorted_sortedsequence_size_error():
with paddle.static.program_guard(paddle.static.Program()):
sorted_sequence = paddle.static.data(
'SortedSequence', shape=[2, 2, pow(2, 34)], dtype="float64")
values = paddle.static.data(
'Values', shape=[2, 2, 5], dtype="float64")
out = paddle.searchsorted(
sorted_sequence, values, out_int32=True)
self.assertRaises(RuntimeError,
test_searchsorted_sortedsequence_size_error)
def test_sortedsequence_values_type_error():
with paddle.static.program_guard(paddle.static.Program()):
sorted_sequence = paddle.static.data(
'SortedSequence', shape=[2, 3], dtype="int16")
values = paddle.static.data(
'Values', shape=[2, 5], dtype="int16")
out = paddle.searchsorted(sorted_sequence, values)
self.assertRaises(TypeError, test_sortedsequence_values_type_error)
if __name__ == '__main__':
unittest.main()
...@@ -190,6 +190,7 @@ from .random import randperm # noqa: F401 ...@@ -190,6 +190,7 @@ from .random import randperm # noqa: F401
from .search import argmax # noqa: F401 from .search import argmax # noqa: F401
from .search import argmin # noqa: F401 from .search import argmin # noqa: F401
from .search import argsort # noqa: F401 from .search import argsort # noqa: F401
from .search import searchsorted # noqa: F401
from .search import topk # noqa: F401 from .search import topk # noqa: F401
from .search import where # noqa: F401 from .search import where # noqa: F401
from .search import index_select # noqa: F401 from .search import index_select # noqa: F401
......
...@@ -766,3 +766,75 @@ def topk(x, k, axis=None, largest=True, sorted=True, name=None): ...@@ -766,3 +766,75 @@ def topk(x, k, axis=None, largest=True, sorted=True, name=None):
attrs=attrs) attrs=attrs)
indices.stop_gradient = True indices.stop_gradient = True
return values, indices return values, indices
def searchsorted(sorted_sequence,
values,
out_int32=False,
right=False,
name=None):
"""
This OP is used to find the index of the corresponding `sorted_sequence` in the innermost dimension based on the given `values`.
Args:
sorted_sequence(Tensor): An input N-D or 1-D tensor with type int32, int64, float32, float64. The value of the tensor monotonically increases in the innermost dimension.
values(Tensor): An input N-D tensor value with type int32, int64, float32, float64.
out_int32(bool, optional): Data type of the output tensor which can be int32, int64. The default value is False, and it indicates that the output data type is int64.
right(bool, optional): Find the upper or lower bounds of the sorted_sequence range in the innermost dimension based on the given `values`. If the value of the sorted_sequence is nan or inf, return the size of the innermost dimension.
The default value is False and it shows the lower bounds.
name(str, optional): The default value is None. Normally there is no need for user to set this property. For more information, please refer to :ref:`api_guide_Name`.
Returns:
Tensor(the same sizes of the `values`), return the tensor of int32 if set :attr:`out_int32` is True, otherwise return the tensor of int64.
Examples:
.. code-block:: python
import paddle
sorted_sequence = paddle.to_tensor([[1, 3, 5, 7, 9, 11],
[2, 4, 6, 8, 10, 12]], dtype='int32')
values = paddle.to_tensor([[3, 6, 9, 10], [3, 6, 9, 10]], dtype='int32')
out1 = paddle.searchsorted(sorted_sequence, values)
print(out1)
# Tensor(shape=[2, 4], dtype=int64, place=CUDAPlace(0), stop_gradient=True,
# [[1, 3, 4, 5],
# [1, 2, 4, 4]])
out2 = paddle.searchsorted(sorted_sequence, values, right=True)
print(out2)
# Tensor(shape=[2, 4], dtype=int64, place=CUDAPlace(0), stop_gradient=True,
# [[2, 3, 5, 5],
# [1, 3, 4, 5]])
sorted_sequence_1d = paddle.to_tensor([1, 3, 5, 7, 9, 11, 13])
out3 = paddle.searchsorted(sorted_sequence_1d, values)
print(out3)
# Tensor(shape=[2, 4], dtype=int64, place=CUDAPlace(0), stop_gradient=True,
# [[1, 3, 4, 5],
# [1, 3, 4, 5]])
"""
if in_dygraph_mode():
return _C_ops.searchsorted(sorted_sequence, values, "out_int32",
out_int32, "right", right)
check_variable_and_dtype(sorted_sequence, 'SortedSequence',
['float32', 'float64', 'int32', 'int64'],
'paddle.searchsorted')
check_variable_and_dtype(values, 'Values',
['float32', 'float64', 'int32', 'int64'],
'paddle.searchsorted')
helper = LayerHelper('searchsorted', **locals())
out_type = 'int32' if out_int32 else 'int64'
out = helper.create_variable_for_type_inference(dtype=out_type)
helper.append_op(
type='searchsorted',
inputs={'SortedSequence': sorted_sequence,
"Values": values},
outputs={'Out': out},
attrs={"out_int32": out_int32,
"right": right})
return out
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册