Unverified commit 85f8fd9b authored by Zhang Zheng, committed by GitHub

[Phi]Move searchsorted kernel to phi (#40520)

Parent 1a32391c
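For context: searchsorted takes a sorted boundary tensor and a tensor of query values, and returns for each value the index at which it could be inserted while keeping the sequence sorted; the right attribute picks the last admissible slot (upper bound) instead of the first (lower bound). Below is a standalone reference model of that contract using only the standard library; it is illustrative and not part of the patch.

// Reference semantics of searchsorted (illustrative only): insertion index
// of `value` into `sorted_seq`; first admissible slot for right=false,
// last admissible slot for right=true.
#include <algorithm>
#include <cstdint>
#include <iostream>
#include <vector>

int64_t SearchSortedRef(const std::vector<float>& sorted_seq, float value,
                        bool right) {
  auto it = right
                ? std::upper_bound(sorted_seq.begin(), sorted_seq.end(), value)
                : std::lower_bound(sorted_seq.begin(), sorted_seq.end(), value);
  return it - sorted_seq.begin();
}

int main() {
  const std::vector<float> seq = {1.f, 3.f, 5.f, 7.f};
  std::cout << SearchSortedRef(seq, 5.f, false) << '\n';  // 2 (lower bound)
  std::cout << SearchSortedRef(seq, 5.f, true) << '\n';   // 3 (upper bound)
}

The kernels moved in this commit compute exactly these indices, element-wise, on CPU or GPU.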
@@ -12,8 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/searchsorted_op.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/platform/enforce.h"
namespace paddle {
@@ -117,10 +116,3 @@ class SearchSortedOpMaker : public framework::OpProtoAndCheckerMaker {
namespace ops = paddle::operators;
REGISTER_OPERATOR(searchsorted, ops::SearchSortedOp, ops::SearchSortedOpMaker);
-REGISTER_OP_CPU_KERNEL(
-searchsorted,
-ops::SearchSortedKernel<paddle::platform::CPUDeviceContext, float>,
-ops::SearchSortedKernel<paddle::platform::CPUDeviceContext, double>,
-ops::SearchSortedKernel<paddle::platform::CPUDeviceContext, int>,
-ops::SearchSortedKernel<paddle::platform::CPUDeviceContext, int64_t>);
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+#include "paddle/phi/kernels/searchsorted_kernel.h"
+#include "paddle/phi/backends/cpu/cpu_context.h"
+#include "paddle/phi/core/kernel_registry.h"
+#include "paddle/phi/kernels/impl/searchsorted_kernel_impl.h"
+PD_REGISTER_KERNEL(searchsorted,
+CPU,
+ALL_LAYOUT,
+phi::SearchsortedKernel,
+float,
+double,
+int,
+int64_t) {}
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+#include "paddle/phi/kernels/searchsorted_kernel.h"
+#include "paddle/phi/backends/gpu/gpu_context.h"
+#include "paddle/phi/core/kernel_registry.h"
+#include "paddle/phi/kernels/impl/searchsorted_kernel_impl.h"
+PD_REGISTER_KERNEL(searchsorted,
+GPU,
+ALL_LAYOUT,
+phi::SearchsortedKernel,
+float,
+double,
+int,
+int64_t) {}
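Note that the CPU and GPU registrations compile the same header-only implementation; the only degree of freedom is the Context type. The funcs::ForRange<Context> abstraction used in the impl below is what makes that work: specialized on a CPU context it degenerates to a serial loop, on a GPU context it launches a kernel over the same index space. A minimal sketch of that shape, with hypothetical names (MiniForRange, FakeCPUContext), not phi's actual code:

// Hedged sketch of the ForRange idea: construct with (context, limit), then
// call with a functor taking the flat index. The CPU specialization shown
// here is a serial loop; a GPU specialization would launch a CUDA kernel
// over the same indices instead.
#include <cstddef>
#include <cstdint>
#include <iostream>

struct FakeCPUContext {};  // stand-in for phi::CPUContext

template <typename Context>
struct MiniForRange;  // hypothetical name; phi's is funcs::ForRange<Context>

template <>
struct MiniForRange<FakeCPUContext> {
  MiniForRange(const FakeCPUContext&, std::size_t limit) : limit_(limit) {}
  template <typename F>
  void operator()(F f) const {
    for (std::size_t i = 0; i < limit_; ++i) f(i);
  }
  std::size_t limit_;
};

int main() {
  FakeCPUContext ctx;
  std::int64_t out[4] = {0, 0, 0, 0};
  MiniForRange<FakeCPUContext> for_range(ctx, 4);
  for_range([&](std::size_t i) { out[i] = static_cast<std::int64_t>(i * i); });
  for (auto v : out) std::cout << v << ' ';  // prints: 0 1 4 9
  std::cout << '\n';
}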
-// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -16,16 +16,11 @@
#include <math.h>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/for_range.h"
#include "paddle/phi/core/ddim.h"
#include "paddle/phi/kernels/funcs/algorithm.h"
#include "paddle/phi/kernels/funcs/for_range.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
namespace phi {
template <typename T1, typename T2, typename OutType>
class GpuAndCpuSearchSortedCompute {
@@ -65,9 +60,11 @@ class GpuAndCpuSearchSortedCompute {
static HOSTDEVICE bool IsInf(int64_t x) { return false; }
HOSTDEVICE GpuAndCpuSearchSortedCompute(const T1* sequence_data,
-const T2* value_data, bool right,
+const T2* value_data,
+bool right,
bool is_1d_boundaries,
-int64_t val_size, int64_t seq_size,
+int64_t val_size,
+int64_t seq_size,
OutType* out_data)
: sequence_data_(sequence_data),
value_data_(value_data),
@@ -104,12 +101,13 @@ class GpuAndCpuSearchSortedCompute {
OutType* out_data_;
};
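The hunk above elides the body of the functor's operator(). Judging from the members, it selects the matching row of the boundary tensor when boundaries are batched and then binary-searches seq_size elements; here is a reconstruction under that assumption (SearchOne is a made-up name, and std::lower_bound/std::upper_bound stand in for the HOSTDEVICE helpers of funcs/algorithm.h that let the same code run on CPU and GPU):

// Reconstruction of the elided per-index search (an assumption, illustrative):
// with batched boundaries, row i / val_size of the sequence is searched; the
// result is the lower bound (right=false) or upper bound (right=true).
#include <algorithm>
#include <cstdint>
#include <iostream>

template <typename T1, typename T2, typename OutType>
OutType SearchOne(const T1* sequence_data, const T2* value_data, bool right,
                  bool is_1d_boundaries, int64_t val_size, int64_t seq_size,
                  int64_t i) {
  const T1* seq = is_1d_boundaries
                      ? sequence_data
                      : sequence_data + i / val_size * seq_size;
  const T1* pos = right ? std::upper_bound(seq, seq + seq_size, value_data[i])
                        : std::lower_bound(seq, seq + seq_size, value_data[i]);
  return static_cast<OutType>(pos - seq);
}

int main() {
  const float seq[8] = {1, 3, 5, 7, 2, 4, 6, 8};  // two rows of 4 boundaries
  const float vals[2] = {5.f, 5.f};               // one query per row
  for (int64_t i = 0; i < 2; ++i)
    std::cout << SearchOne<float, float, int64_t>(seq, vals, false, false,
                                                  /*val_size=*/1,
                                                  /*seq_size=*/4, i)
              << ' ';  // prints: 2 2
  std::cout << '\n';
}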
-template <typename DeviceContext, typename T1, typename OutType>
+template <typename Context, typename T1, typename OutType>
class SearchSortedFunctor {
public:
-SearchSortedFunctor(const framework::ExecutionContext& context,
-const framework::Tensor* sorted_sequence,
-const framework::Tensor* value, bool right,
+SearchSortedFunctor(const Context& context,
+const DenseTensor* sorted_sequence,
+const DenseTensor* value,
+bool right,
OutType* out_data)
: context_(context),
sorted_sequence_(sorted_sequence),
@@ -121,74 +119,73 @@ class SearchSortedFunctor {
void apply() {
const T1* sequence_data = sorted_sequence_->data<T1>();
const T2* value_data = value_->data<T2>();
-const framework::DDim& seq_dims = sorted_sequence_->dims();
-const framework::DDim& val_dims = value_->dims();
+const phi::DDim& seq_dims = sorted_sequence_->dims();
+const phi::DDim& val_dims = value_->dims();
bool is_1d_boundaries = seq_dims.size() == 1;
int64_t val_size = val_dims[val_dims.size() - 1];
int64_t seq_size = seq_dims[seq_dims.size() - 1];
-auto& dev_ctx = context_.template device_context<DeviceContext>();
-platform::ForRange<DeviceContext> for_range(dev_ctx, value_->numel());
+funcs::ForRange<Context> for_range(context_, value_->numel());
GpuAndCpuSearchSortedCompute<T1, T2, OutType>
-gpu_and_cpu_search_sorted_compute(sequence_data, value_data, right_,
-is_1d_boundaries, val_size, seq_size,
+gpu_and_cpu_search_sorted_compute(sequence_data,
+value_data,
+right_,
+is_1d_boundaries,
+val_size,
+seq_size,
out_data_);
for_range(gpu_and_cpu_search_sorted_compute);
}
private:
-const framework::ExecutionContext& context_;
-const framework::Tensor* sorted_sequence_;
-const framework::Tensor* value_;
+const Context& context_;
+const DenseTensor* sorted_sequence_;
+const DenseTensor* value_;
bool right_;
OutType* out_data_;
};
template <typename Visitor>
-static void VisitDataType(framework::proto::VarType::Type type,
-Visitor visitor) {
-if (type == framework::proto::VarType::FP32) {
+static void VisitDataType(DataType type, Visitor visitor) {
+if (type == DataType::FLOAT32) {
visitor.template apply<float>();
-} else if (type == framework::proto::VarType::FP64) {
+} else if (type == DataType::FLOAT64) {
visitor.template apply<double>();
-} else if (type == framework::proto::VarType::INT32) {
+} else if (type == DataType::INT32) {
visitor.template apply<int>();
-} else if (type == framework::proto::VarType::INT64) {
+} else if (type == DataType::INT64) {
visitor.template apply<int64_t>();
} else {
-PADDLE_THROW(platform::errors::InvalidArgument(
+PADDLE_THROW(errors::InvalidArgument(
"The received values data type %s cannot meet input requirements. "
"Because the given values data type of searchsorted operators must be "
"float32, float64, int32 or int64. Please input appropriate "
"sorted_sequence again! ",
-framework::DataTypeToString(type)));
+type));
}
}
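VisitDataType bridges the runtime dtype of the values tensor to a compile-time template argument: each branch instantiates the functor's templated apply<T2>(). A toy visitor that satisfies the same contract (hypothetical example, not part of the patch):

// Toy visitor obeying the apply<T2>() shape that VisitDataType dispatches on.
#include <cstdint>
#include <iostream>

struct PrintTypeSize {
  template <typename T2>
  void apply() {
    std::cout << "dispatched on a " << sizeof(T2) << "-byte element type\n";
  }
};

int main() {
  PrintTypeSize visitor;
  visitor.apply<float>();    // what the DataType::FLOAT32 branch does
  visitor.apply<int64_t>();  // what the DataType::INT64 branch does
}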
-template <typename DeviceContext, typename T>
-class SearchSortedKernel : public framework::OpKernel<T> {
-public:
-void Compute(const framework::ExecutionContext& context) const override {
-auto* sorted_sequence = context.Input<Tensor>("SortedSequence");
-auto* value = context.Input<Tensor>("Values");
-bool out_int32 = context.Attr<bool>("out_int32");
-bool right = context.Attr<bool>("right");
-auto* out = context.Output<Tensor>("Out");
+template <typename T, typename Context>
+void SearchsortedKernel(const Context& ctx,
+const DenseTensor& sorted_sequence,
+const DenseTensor& value,
+bool out_int32,
+bool right,
+DenseTensor* out) {
if (out_int32) {
-int* out_data = out->mutable_data<int>(context.GetPlace());
-SearchSortedFunctor<DeviceContext, T, int> functor(
-context, sorted_sequence, value, right, out_data);
-VisitDataType(framework::TransToProtoVarType(value->dtype()), functor);
+ctx.template Alloc<int>(out);
+int* out_data = out->data<int>();
+SearchSortedFunctor<Context, T, int> functor(
+ctx, &sorted_sequence, &value, right, out_data);
+VisitDataType(value.dtype(), functor);
} else {
-int64_t* out_data = out->mutable_data<int64_t>(context.GetPlace());
-SearchSortedFunctor<DeviceContext, T, int64_t> functor(
-context, sorted_sequence, value, right, out_data);
-VisitDataType(framework::TransToProtoVarType(value->dtype()), functor);
-}
+ctx.template Alloc<int64_t>(out);
+int64_t* out_data = out->data<int64_t>();
+SearchSortedFunctor<Context, T, int64_t> functor(
+ctx, &sorted_sequence, &value, right, out_data);
+VisitDataType(value.dtype(), functor);
+}
-};
+}
-} // namespace operators
-} // namespace paddle
+} // namespace phi
-// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -12,12 +12,18 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/searchsorted_op.h"
namespace ops = paddle::operators;
namespace plat = paddle::platform;
#pragma once
REGISTER_OP_CUDA_KERNEL(
searchsorted, ops::SearchSortedKernel<plat::CUDADeviceContext, float>,
ops::SearchSortedKernel<plat::CUDADeviceContext, double>,
ops::SearchSortedKernel<plat::CUDADeviceContext, int>,
ops::SearchSortedKernel<plat::CUDADeviceContext, int64_t>);
#include "paddle/phi/core/dense_tensor.h"
namespace phi {
template <typename T, typename Context>
void SearchsortedKernel(const Context& ctx,
const DenseTensor& sorted_sequence,
const DenseTensor& value,
bool out_int32,
bool right,
DenseTensor* out);
} // namespace phi
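For reference, a hedged sketch of how a C++ caller inside phi could invoke the relocated kernel directly through this declaration; dev_ctx, sorted_sequence, and values are assumed to exist already, and the surrounding allocation plumbing is elided, so this is a fragment rather than a complete program:

// Fragment (illustrative): assumes dev_ctx is a phi::CPUContext and
// sorted_sequence / values are already-built phi::DenseTensor inputs.
phi::DenseTensor out;
out.Resize(values.dims());
phi::SearchsortedKernel<float, phi::CPUContext>(
    dev_ctx, sorted_sequence, values, /*out_int32=*/false, /*right=*/true,
    &out);
// out now holds int64 insertion indices; out_int32=true requests int32.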