Unverified commit 85f8fd9b, authored by Zhang Zheng, committed by GitHub

[Phi] Move searchsorted kernel to phi (#40520)

Parent: 1a32391c
paddle/fluid/operators/searchsorted_op.cc

@@ -12,8 +12,7 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
-#include "paddle/fluid/operators/searchsorted_op.h"
+#include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/platform/enforce.h"
 
 namespace paddle {
@@ -117,10 +116,3 @@ class SearchSortedOpMaker : public framework::OpProtoAndCheckerMaker {
 namespace ops = paddle::operators;
 REGISTER_OPERATOR(searchsorted, ops::SearchSortedOp, ops::SearchSortedOpMaker);
-REGISTER_OP_CPU_KERNEL(
-    searchsorted,
-    ops::SearchSortedKernel<paddle::platform::CPUDeviceContext, float>,
-    ops::SearchSortedKernel<paddle::platform::CPUDeviceContext, double>,
-    ops::SearchSortedKernel<paddle::platform::CPUDeviceContext, int>,
-    ops::SearchSortedKernel<paddle::platform::CPUDeviceContext, int64_t>);
paddle/phi/kernels/cpu/searchsorted_kernel.cc (new file)

// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/searchsorted_kernel.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/searchsorted_kernel_impl.h"
PD_REGISTER_KERNEL(searchsorted,
CPU,
ALL_LAYOUT,
phi::SearchsortedKernel,
float,
double,
int,
int64_t) {}
paddle/phi/kernels/gpu/searchsorted_kernel.cu (new file)

// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/searchsorted_kernel.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/searchsorted_kernel_impl.h"
PD_REGISTER_KERNEL(searchsorted,
GPU,
ALL_LAYOUT,
phi::SearchsortedKernel,
float,
double,
int,
int64_t) {}
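
Both registrations bind the same templated phi::SearchsortedKernel; the device-specific looping is hidden behind phi's funcs::ForRange, which applies an element-wise functor as a serial loop on CPU or across GPU threads. A minimal CPU-side sketch of the ForRange idea (an illustration of the pattern only, not phi's actual implementation):

#include <cstdint>

// Applies func(i) for every i in [0, limit). A CPU ForRange is essentially
// this serial loop; a GPU ForRange instead launches a kernel in which each
// thread calls func(thread_index). Because the functor is identical either
// way, the searchsorted logic below is written only once.
template <typename Functor>
void CpuForRange(int64_t limit, Functor func) {
  for (int64_t i = 0; i < limit; ++i) {
    func(i);
  }
}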
paddle/fluid/operators/searchsorted_op.h → paddle/phi/kernels/impl/searchsorted_kernel_impl.h

-// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
 //
 // Licensed under the Apache License, Version 2.0 (the "License");
 // you may not use this file except in compliance with the License.
@@ -16,16 +16,11 @@
 #include <math.h>
 
-#include "paddle/fluid/framework/op_registry.h"
-#include "paddle/fluid/framework/operator.h"
-#include "paddle/fluid/platform/enforce.h"
-#include "paddle/fluid/platform/for_range.h"
 #include "paddle/phi/core/ddim.h"
 #include "paddle/phi/kernels/funcs/algorithm.h"
+#include "paddle/phi/kernels/funcs/for_range.h"
 
-namespace paddle {
-namespace operators {
-
-using Tensor = framework::Tensor;
+namespace phi {
 
 template <typename T1, typename T2, typename OutType>
 class GpuAndCpuSearchSortedCompute {
@@ -65,9 +60,11 @@ class GpuAndCpuSearchSortedCompute {
   static HOSTDEVICE bool IsInf(int64_t x) { return false; }
 
   HOSTDEVICE GpuAndCpuSearchSortedCompute(const T1* sequence_data,
-                                          const T2* value_data, bool right,
+                                          const T2* value_data,
+                                          bool right,
                                           bool is_1d_boundaries,
-                                          int64_t val_size, int64_t seq_size,
+                                          int64_t val_size,
+                                          int64_t seq_size,
                                           OutType* out_data)
       : sequence_data_(sequence_data),
         value_data_(value_data),
@@ -104,12 +101,13 @@ class GpuAndCpuSearchSortedCompute {
   OutType* out_data_;
 };
 
-template <typename DeviceContext, typename T1, typename OutType>
+template <typename Context, typename T1, typename OutType>
 class SearchSortedFunctor {
  public:
-  SearchSortedFunctor(const framework::ExecutionContext& context,
-                      const framework::Tensor* sorted_sequence,
-                      const framework::Tensor* value, bool right,
+  SearchSortedFunctor(const Context& context,
+                      const DenseTensor* sorted_sequence,
+                      const DenseTensor* value,
+                      bool right,
                       OutType* out_data)
       : context_(context),
         sorted_sequence_(sorted_sequence),
@@ -121,74 +119,73 @@ class SearchSortedFunctor {
   void apply() {
     const T1* sequence_data = sorted_sequence_->data<T1>();
     const T2* value_data = value_->data<T2>();
-    const framework::DDim& seq_dims = sorted_sequence_->dims();
-    const framework::DDim& val_dims = value_->dims();
+    const phi::DDim& seq_dims = sorted_sequence_->dims();
+    const phi::DDim& val_dims = value_->dims();
 
     bool is_1d_boundaries = seq_dims.size() == 1;
     int64_t val_size = val_dims[val_dims.size() - 1];
     int64_t seq_size = seq_dims[seq_dims.size() - 1];
 
-    auto& dev_ctx = context_.template device_context<DeviceContext>();
-    platform::ForRange<DeviceContext> for_range(dev_ctx, value_->numel());
+    funcs::ForRange<Context> for_range(context_, value_->numel());
     GpuAndCpuSearchSortedCompute<T1, T2, OutType>
-        gpu_and_cpu_search_sorted_compute(sequence_data, value_data, right_,
-                                          is_1d_boundaries, val_size, seq_size,
+        gpu_and_cpu_search_sorted_compute(sequence_data,
+                                          value_data,
+                                          right_,
+                                          is_1d_boundaries,
+                                          val_size,
+                                          seq_size,
                                           out_data_);
     for_range(gpu_and_cpu_search_sorted_compute);
   }
 
  private:
-  const framework::ExecutionContext& context_;
-  const framework::Tensor* sorted_sequence_;
-  const framework::Tensor* value_;
+  const Context& context_;
+  const DenseTensor* sorted_sequence_;
+  const DenseTensor* value_;
   bool right_;
   OutType* out_data_;
 };
 
 template <typename Visitor>
-static void VisitDataType(framework::proto::VarType::Type type,
-                          Visitor visitor) {
-  if (type == framework::proto::VarType::FP32) {
+static void VisitDataType(DataType type, Visitor visitor) {
+  if (type == DataType::FLOAT32) {
     visitor.template apply<float>();
-  } else if (type == framework::proto::VarType::FP64) {
+  } else if (type == DataType::FLOAT64) {
     visitor.template apply<double>();
-  } else if (type == framework::proto::VarType::INT32) {
+  } else if (type == DataType::INT32) {
     visitor.template apply<int>();
-  } else if (type == framework::proto::VarType::INT64) {
+  } else if (type == DataType::INT64) {
    visitor.template apply<int64_t>();
   } else {
-    PADDLE_THROW(platform::errors::InvalidArgument(
+    PADDLE_THROW(errors::InvalidArgument(
         "The received values data type %s can not meet input requirements. "
         "Because the given values data type of searchsorted operators must be "
         "float32, float64, int32 or int64. Please input appropriate "
-        "sorted_sequence again! ",
-        framework::DataTypeToString(type)));
+        "sorted_sequence again! ",
+        type));
   }
 }
 
-template <typename DeviceContext, typename T>
-class SearchSortedKernel : public framework::OpKernel<T> {
- public:
-  void Compute(const framework::ExecutionContext& context) const override {
-    auto* sorted_sequence = context.Input<Tensor>("SortedSequence");
-    auto* value = context.Input<Tensor>("Values");
-    bool out_int32 = context.Attr<bool>("out_int32");
-    bool right = context.Attr<bool>("right");
-    auto* out = context.Output<Tensor>("Out");
-
-    if (out_int32) {
-      int* out_data = out->mutable_data<int>(context.GetPlace());
-      SearchSortedFunctor<DeviceContext, T, int> functor(
-          context, sorted_sequence, value, right, out_data);
-      VisitDataType(framework::TransToProtoVarType(value->dtype()), functor);
-    } else {
-      int64_t* out_data = out->mutable_data<int64_t>(context.GetPlace());
-      SearchSortedFunctor<DeviceContext, T, int64_t> functor(
-          context, sorted_sequence, value, right, out_data);
-      VisitDataType(framework::TransToProtoVarType(value->dtype()), functor);
-    }
-  }
-};
+template <typename T, typename Context>
+void SearchsortedKernel(const Context& ctx,
+                        const DenseTensor& sorted_sequence,
+                        const DenseTensor& value,
+                        bool out_int32,
+                        bool right,
+                        DenseTensor* out) {
+  if (out_int32) {
+    ctx.template Alloc<int>(out);
+    int* out_data = out->data<int>();
+    SearchSortedFunctor<Context, T, int> functor(
+        ctx, &sorted_sequence, &value, right, out_data);
+    VisitDataType(value.dtype(), functor);
+  } else {
+    ctx.template Alloc<int64_t>(out);
+    int64_t* out_data = out->data<int64_t>();
+    SearchSortedFunctor<Context, T, int64_t> functor(
+        ctx, &sorted_sequence, &value, right, out_data);
+    VisitDataType(value.dtype(), functor);
+  }
+}
 
-}  // namespace operators
-}  // namespace paddle
+}  // namespace phi
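
The body of GpuAndCpuSearchSortedCompute elided by the diff performs a per-element binary search: for each entry of value it computes the insertion index into the matching row of sorted_sequence, taking the upper bound when right is true and the lower bound otherwise. A self-contained sketch of that core logic for a 1-D boundary sequence (a simplified illustration, not the phi code, which also handles N-D boundaries, NaN, and infinities):

#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <vector>

// Simplified searchsorted over a 1-D sorted sequence: for each value v,
// return the index at which v could be inserted while keeping seq sorted.
// right == false -> first index i with seq[i] >= v (lower bound)
// right == true  -> first index i with seq[i] >  v (upper bound)
std::vector<int64_t> SearchSorted1D(const std::vector<double>& seq,
                                    const std::vector<double>& values,
                                    bool right) {
  std::vector<int64_t> out(values.size());
  for (std::size_t i = 0; i < values.size(); ++i) {
    auto it = right ? std::upper_bound(seq.begin(), seq.end(), values[i])
                    : std::lower_bound(seq.begin(), seq.end(), values[i]);
    out[i] = static_cast<int64_t>(it - seq.begin());
  }
  return out;
}

// Example: SearchSorted1D({1, 3, 5, 7}, {4}, false) returns {2}.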
paddle/fluid/operators/searchsorted_op.cu → paddle/phi/kernels/searchsorted_kernel.h

-// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
 //
 // Licensed under the Apache License, Version 2.0 (the "License");
 // you may not use this file except in compliance with the License.
@@ -12,12 +12,18 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
-#include "paddle/fluid/operators/searchsorted_op.h"
+#pragma once
 
-namespace ops = paddle::operators;
-namespace plat = paddle::platform;
+#include "paddle/phi/core/dense_tensor.h"
 
-REGISTER_OP_CUDA_KERNEL(
-    searchsorted, ops::SearchSortedKernel<plat::CUDADeviceContext, float>,
-    ops::SearchSortedKernel<plat::CUDADeviceContext, double>,
-    ops::SearchSortedKernel<plat::CUDADeviceContext, int>,
-    ops::SearchSortedKernel<plat::CUDADeviceContext, int64_t>);
+namespace phi {
+
+template <typename T, typename Context>
+void SearchsortedKernel(const Context& ctx,
+                        const DenseTensor& sorted_sequence,
+                        const DenseTensor& value,
+                        bool out_int32,
+                        bool right,
+                        DenseTensor* out);
+
+}  // namespace phi
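
Note the double dispatch implied by this signature: T, the dtype of sorted_sequence, is fixed at compile time by the PD_REGISTER_KERNEL lists above, while the dtype of value is only known at runtime and is branched on by the VisitDataType visitor in the impl header. A standalone sketch of that pattern (hypothetical names, not the phi API):

#include <cstdint>
#include <cstdio>

enum class DataType { FLOAT32, INT64 };

// Visitor whose apply<T2>() runs the real work for the values dtype T2;
// T1 (the sequence dtype) was already fixed at compile time.
template <typename T1>
struct SearchSortedVisitor {
  template <typename T2>
  void apply() const {
    std::printf("search %zu-byte values in a %zu-byte sequence\n",
                sizeof(T2), sizeof(T1));
  }
};

template <typename Visitor>
void VisitDataType(DataType type, Visitor visitor) {
  if (type == DataType::FLOAT32) {
    visitor.template apply<float>();
  } else {  // DataType::INT64
    visitor.template apply<int64_t>();
  }
}

int main() {
  // Sequence dtype (double) chosen at compile time; values dtype at runtime.
  VisitDataType(DataType::FLOAT32, SearchSortedVisitor<double>{});
  return 0;
}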