From 34122e3ec281d59c15cc2e6cc1f7f27bc5bc2431 Mon Sep 17 00:00:00 2001
From: zhangyuqin1998 <75946871+zhangyuqin1998@users.noreply.github.com>
Date: Mon, 15 May 2023 10:11:45 +0800
Subject: [PATCH] move OneHotRawKernel to legacy (#53200)

* move OneHotRawKernel to legacy

* fix
---
 paddle/phi/kernels/cpu/one_hot_kernel.cc      | 41 +++++---
 paddle/phi/kernels/gpu/one_hot_kernel.cu      | 59 ++++--------
 .../phi/kernels/legacy/cpu/one_hot_kernel.cc  | 88 +++++++++++++++++
 .../phi/kernels/legacy/gpu/one_hot_kernel.cu  | 96 +++++++++++++++++++
 .../phi/kernels/legacy/xpu/one_hot_kernel.cc  | 68 +++++++++++++
 paddle/phi/kernels/one_hot_kernel.cc          | 47 ---------
 paddle/phi/kernels/one_hot_kernel.h           |  8 --
 paddle/phi/kernels/xpu/one_hot_kernel.cc      | 47 +++------
 8 files changed, 312 insertions(+), 142 deletions(-)
 create mode 100644 paddle/phi/kernels/legacy/cpu/one_hot_kernel.cc
 create mode 100644 paddle/phi/kernels/legacy/gpu/one_hot_kernel.cu
 create mode 100644 paddle/phi/kernels/legacy/xpu/one_hot_kernel.cc
 delete mode 100644 paddle/phi/kernels/one_hot_kernel.cc

diff --git a/paddle/phi/kernels/cpu/one_hot_kernel.cc b/paddle/phi/kernels/cpu/one_hot_kernel.cc
index 9d1daaf065f..0958e2c02b4 100644
--- a/paddle/phi/kernels/cpu/one_hot_kernel.cc
+++ b/paddle/phi/kernels/cpu/one_hot_kernel.cc
@@ -63,12 +63,10 @@ struct OneHotV2OpFunctor {
 };
 
 template <typename T, typename Context>
-void OneHotRawKernel(const Context& dev_ctx,
-                     const DenseTensor& x,
-                     const Scalar& depth,
-                     DataType dtype,
-                     bool allow_out_of_range,
-                     DenseTensor* out) {
+void OneHotKernel(const Context& dev_ctx,
+                  const DenseTensor& x,
+                  const Scalar& depth,
+                  DenseTensor* out) {
   auto depth_v = depth.to<int>();
   auto out_dims = out->dims();
   if (out_dims[out_dims.size() - 1] == -1) {
@@ -76,13 +74,34 @@ void OneHotRawKernel(const Context& dev_ctx,
     out->Resize(out_dims);
   }
 
-  phi::VisitDataType(dtype,
-                     OneHotV2OpFunctor<Context, T>(&x, out, depth_v, dev_ctx));
+  auto* p_in_data = x.data<T>();
+  auto numel = x.numel();
+  auto* p_out_data = dev_ctx.template Alloc<float>(out);
+  funcs::set_constant(dev_ctx, out, 0.0);
+
+  for (int i = 0; i < numel; ++i) {
+    PADDLE_ENFORCE_GE(
+        p_in_data[i],
+        0,
+        phi::errors::InvalidArgument(
+            "Illegal index value, Input(input) value should be at least 0, "
+            "but received input (%d) less than 0",
+            p_in_data[i]));
+    PADDLE_ENFORCE_LT(
+        p_in_data[i],
+        depth_v,
+        phi::errors::InvalidArgument(
+            "Illegal index value, Input(input) value should be less than "
+            "Input(depth), "
+            "but received input (%d) not less than depth (%d)",
+            p_in_data[i],
+            depth_v));
+    *(p_out_data + i * depth_v + p_in_data[i]) = 1.0;
+  }
 }
 
 }  // namespace phi
 
-PD_REGISTER_KERNEL(
-    one_hot_raw, CPU, ALL_LAYOUT, phi::OneHotRawKernel, int, int64_t) {
-  kernel->OutputAt(0).SetDataType(phi::DataType::UNDEFINED);
+PD_REGISTER_KERNEL(one_hot, CPU, ALL_LAYOUT, phi::OneHotKernel, int, int64_t) {
+  kernel->OutputAt(0).SetDataType(phi::DataType::FLOAT32);
 }
diff --git a/paddle/phi/kernels/gpu/one_hot_kernel.cu b/paddle/phi/kernels/gpu/one_hot_kernel.cu
index 45d208498a5..63084585d80 100644
--- a/paddle/phi/kernels/gpu/one_hot_kernel.cu
+++ b/paddle/phi/kernels/gpu/one_hot_kernel.cu
@@ -40,43 +40,11 @@ __global__ void FillOutputKernel(const InT* p_in_data,
   }
 }
 
-template <typename DeviceContext, typename InT>
-struct OneHotV2OpCUDAFunctor {
-  const DenseTensor* in_;
-  DenseTensor* out_;
-  const DeviceContext& ctx_;
-  int depth_;
-
-  OneHotV2OpCUDAFunctor(const DenseTensor* in,
-                        DenseTensor* out,
-                        int depth,
-                        const DeviceContext& ctx)
-      : in_(in), out_(out), depth_(depth), ctx_(ctx) {}
-
-  template <typename OutT>
-  void apply() const {
-    auto* p_in_data = in_->data<InT>();
-    auto numel = in_->numel();
-    auto* p_out_data = ctx_.template Alloc<OutT>(out_);
-    auto stream = ctx_.stream();
-    funcs::set_constant(ctx_, out_, 0.0);
-
-    auto config = phi::backends::gpu::GetGpuLaunchConfig1D(ctx_, numel);
-
-    FillOutputKernel<<<config.block_per_grid,
-                       config.thread_per_block,
-                       0,
-                       stream>>>(p_in_data, p_out_data, numel, depth_);
-  }
-};
-
 template <typename T, typename Context>
-void OneHotRawKernel(const Context& dev_ctx,
-                     const DenseTensor& x,
-                     const Scalar& depth,
-                     DataType dtype,
-                     bool allow_out_of_range,
-                     DenseTensor* out) {
+void OneHotKernel(const Context& dev_ctx,
+                  const DenseTensor& x,
+                  const Scalar& depth,
+                  DenseTensor* out) {
   auto depth_v = depth.to<int>();
   auto out_dims = out->dims();
   if (out_dims[out_dims.size() - 1] == -1) {
@@ -84,13 +52,22 @@ void OneHotRawKernel(const Context& dev_ctx,
     out->Resize(out_dims);
   }
 
-  phi::VisitDataType(
-      dtype, OneHotV2OpCUDAFunctor<Context, T>(&x, out, depth_v, dev_ctx));
+  auto* p_in_data = x.data<T>();
+  auto numel = x.numel();
+  auto* p_out_data = dev_ctx.template Alloc<float>(out);
+  auto stream = dev_ctx.stream();
+  funcs::set_constant(dev_ctx, out, 0.0);
+
+  auto config = phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, numel);
+
+  FillOutputKernel<<<config.block_per_grid,
+                     config.thread_per_block,
+                     0,
+                     stream>>>(p_in_data, p_out_data, numel, depth_v);
 }
 
 }  // namespace phi
 
-PD_REGISTER_KERNEL(
-    one_hot_raw, GPU, ALL_LAYOUT, phi::OneHotRawKernel, int, int64_t) {
-  kernel->OutputAt(0).SetDataType(phi::DataType::UNDEFINED);
+PD_REGISTER_KERNEL(one_hot, GPU, ALL_LAYOUT, phi::OneHotKernel, int, int64_t) {
+  kernel->OutputAt(0).SetDataType(phi::DataType::FLOAT32);
 }
diff --git a/paddle/phi/kernels/legacy/cpu/one_hot_kernel.cc b/paddle/phi/kernels/legacy/cpu/one_hot_kernel.cc
new file mode 100644
index 00000000000..040a8559914
--- /dev/null
+++ b/paddle/phi/kernels/legacy/cpu/one_hot_kernel.cc
@@ -0,0 +1,88 @@
+// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/phi/common/scalar.h"
+#include "paddle/phi/core/dense_tensor.h"
+#include "paddle/phi/core/kernel_registry.h"
+#include "paddle/phi/core/utils/data_type.h"
+#include "paddle/phi/kernels/funcs/math_function.h"
+
+namespace phi {
+
+template <typename DeviceContext, typename InT>
+struct OneHotV2OpFunctor {
+  const DenseTensor* in_;
+  DenseTensor* out_;
+  int depth_;
+  const DeviceContext& ctx_;
+
+  OneHotV2OpFunctor(const DenseTensor* in,
+                    DenseTensor* out,
+                    int depth,
+                    const DeviceContext& ctx)
+      : in_(in), out_(out), depth_(depth), ctx_(ctx) {}
+
+  template <typename OutT>
+  void apply() const {
+    auto* p_in_data = in_->data<InT>();
+    auto numel = in_->numel();
+    auto* p_out_data = ctx_.template Alloc<OutT>(out_);
+    funcs::set_constant(ctx_, out_, 0.0);
+
+    for (int i = 0; i < numel; ++i) {
+      PADDLE_ENFORCE_GE(
+          p_in_data[i],
+          0,
+          phi::errors::InvalidArgument(
+              "Illegal index value, Input(input) value should be at least 0, "
+              "but received input (%d) less than 0",
+              p_in_data[i]));
+      PADDLE_ENFORCE_LT(
+          p_in_data[i],
+          depth_,
+          phi::errors::InvalidArgument(
+              "Illegal index value, Input(input) value should be less than "
+              "Input(depth), "
+              "but received input (%d) not less than depth (%d)",
+              p_in_data[i],
+              depth_));
+      *(p_out_data + i * depth_ + p_in_data[i]) = 1.0;
+    }
+  }
+};
+
+template <typename T, typename Context>
+void OneHotRawKernel(const Context& dev_ctx,
+                     const DenseTensor& x,
+                     const Scalar& depth,
+                     DataType dtype,
+                     bool allow_out_of_range,
+                     DenseTensor* out) {
+  auto depth_v = depth.to<int>();
+  auto out_dims = out->dims();
+  if (out_dims[out_dims.size() - 1] == -1) {
+    out_dims[out_dims.size() - 1] = depth_v;
+    out->Resize(out_dims);
+  }
+
+  phi::VisitDataType(dtype,
+                     OneHotV2OpFunctor<Context, T>(&x, out, depth_v, dev_ctx));
+}
+
+}  // namespace phi
+
+PD_REGISTER_KERNEL(
+    one_hot_raw, CPU, ALL_LAYOUT, phi::OneHotRawKernel, int, int64_t) {
+  kernel->OutputAt(0).SetDataType(phi::DataType::UNDEFINED);
+}
diff --git a/paddle/phi/kernels/legacy/gpu/one_hot_kernel.cu b/paddle/phi/kernels/legacy/gpu/one_hot_kernel.cu
new file mode 100644
index 00000000000..c64f2e2d755
--- /dev/null
+++ b/paddle/phi/kernels/legacy/gpu/one_hot_kernel.cu
@@ -0,0 +1,96 @@
+// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/phi/backends/gpu/gpu_launch_config.h"
+#include "paddle/phi/backends/gpu/gpu_primitives.h"
+#include "paddle/phi/common/scalar.h"
+#include "paddle/phi/core/dense_tensor.h"
+#include "paddle/phi/core/kernel_registry.h"
+#include "paddle/phi/kernels/funcs/math_function.h"
+
+namespace phi {
+
+using phi::PADDLE_CUDA_NUM_THREADS;
+
+template <typename InT, typename OutT>
+__global__ void FillOutputKernel(const InT* p_in_data,
+                                 OutT* p_out_data,
+                                 const int64_t numel,
+                                 const int depth) {
+  CUDA_KERNEL_LOOP_TYPE(idx, numel, int64_t) {
+    PADDLE_ENFORCE(p_in_data[idx] >= 0 && p_in_data[idx] < depth,
+                   "Illegal index value, Input(input) value should be "
+                   "greater than or equal to 0, and less than depth [%d], "
+                   "but received [%lld].",
+                   depth,
+                   p_in_data[idx]);
+
+    *(p_out_data + (idx * depth) + p_in_data[idx]) = 1.0;
+  }
+}
+
+template <typename DeviceContext, typename InT>
+struct OneHotV2OpCUDAFunctor {
+  const DenseTensor* in_;
+  DenseTensor* out_;
+  const DeviceContext& ctx_;
+  int depth_;
+
+  OneHotV2OpCUDAFunctor(const DenseTensor* in,
+                        DenseTensor* out,
+                        int depth,
+                        const DeviceContext& ctx)
+      : in_(in), out_(out), depth_(depth), ctx_(ctx) {}
+
+  template <typename OutT>
+  void apply() const {
+    auto* p_in_data = in_->data<InT>();
+    auto numel = in_->numel();
+    auto* p_out_data = ctx_.template Alloc<OutT>(out_);
+    auto stream = ctx_.stream();
+    funcs::set_constant(ctx_, out_, 0.0);
+
+    auto config = phi::backends::gpu::GetGpuLaunchConfig1D(ctx_, numel);
+
+    FillOutputKernel<<<config.block_per_grid,
+                       config.thread_per_block,
+                       0,
+                       stream>>>(p_in_data, p_out_data, numel, depth_);
+  }
+};
+
+template <typename T, typename Context>
+void OneHotRawKernel(const Context& dev_ctx,
+                     const DenseTensor& x,
+                     const Scalar& depth,
+                     DataType dtype,
+                     bool allow_out_of_range,
+                     DenseTensor* out) {
+  auto depth_v = depth.to<int>();
+  auto out_dims = out->dims();
+  if (out_dims[out_dims.size() - 1] == -1) {
+    out_dims[out_dims.size() - 1] = depth_v;
+    out->Resize(out_dims);
+  }
+
+  phi::VisitDataType(
+      dtype, OneHotV2OpCUDAFunctor<Context, T>(&x, out, depth_v, dev_ctx));
+}
+
+}  // namespace phi
+
+PD_REGISTER_KERNEL(
+    one_hot_raw, GPU, ALL_LAYOUT, phi::OneHotRawKernel, int, int64_t) {
+  kernel->OutputAt(0).SetDataType(phi::DataType::UNDEFINED);
+}
diff --git a/paddle/phi/kernels/legacy/xpu/one_hot_kernel.cc b/paddle/phi/kernels/legacy/xpu/one_hot_kernel.cc
new file mode 100644
index 00000000000..02edbd12843
--- /dev/null
+++ b/paddle/phi/kernels/legacy/xpu/one_hot_kernel.cc
@@ -0,0 +1,68 @@
+// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/phi/backends/xpu/enforce_xpu.h"
+#include "paddle/phi/backends/xpu/xpu_context.h"
+#include "paddle/phi/common/scalar.h"
+#include "paddle/phi/core/dense_tensor.h"
+#include "paddle/phi/core/kernel_registry.h"
+#include "paddle/phi/core/utils/data_type.h"
+
+namespace phi {
+template <typename Context, typename InT>
+struct OneHotV2OpFunctor {
+  const DenseTensor* in_;
+  DenseTensor* out_;
+  int depth_;
+  const Context& ctx_;
+
+  OneHotV2OpFunctor(const DenseTensor* in,
+                    DenseTensor* out,
+                    int depth,
+                    const Context& ctx)
+      : in_(in), out_(out), depth_(depth), ctx_(ctx) {}
+
+  template <typename OutT>
+  void apply() const {
+    auto* p_in_data = in_->data<InT>();
+    auto numel = in_->numel();
+    auto* p_out_data = ctx_.template Alloc<OutT>(out_);
+    int r = xpu::one_hot<InT>(
+        ctx_.x_context(), p_in_data, p_out_data, numel, depth_, 1.0, 0.0);
+    PADDLE_ENFORCE_XDNN_SUCCESS(r, "one_hot");
+  }
+};
+
+template <typename T, typename Context>
+void OneHotRawKernel(const Context& dev_ctx,
+                     const DenseTensor& x,
+                     const Scalar& depth,
+                     DataType dtype,
+                     bool allow_out_of_range,
+                     DenseTensor* out) {
+  auto depth_v = depth.to<int>();
+  auto out_dims = out->dims();
+  if (out_dims[out_dims.size() - 1] == -1) {
+    out_dims[out_dims.size() - 1] = depth_v;
+    out->Resize(out_dims);
+  }
+  phi::VisitDataType(dtype,
+                     OneHotV2OpFunctor<Context, T>(&x, out, depth_v, dev_ctx));
+}
+}  // namespace phi
+
+PD_REGISTER_KERNEL(
+    one_hot_raw, XPU, ALL_LAYOUT, phi::OneHotRawKernel, int, int64_t) {
+  kernel->OutputAt(0).SetDataType(phi::DataType::UNDEFINED);
+}
diff --git a/paddle/phi/kernels/one_hot_kernel.cc b/paddle/phi/kernels/one_hot_kernel.cc
deleted file mode 100644
index 3ad4799ef8a..00000000000
--- a/paddle/phi/kernels/one_hot_kernel.cc
+++ /dev/null
@@ -1,47 +0,0 @@
-// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-//     http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "paddle/phi/kernels/one_hot_kernel.h"
-
-#include "paddle/phi/backends/all_context.h"
-#include "paddle/phi/core/kernel_registry.h"
-
-namespace phi {
-
-template <typename T, typename Context>
-void OneHotKernel(const Context& dev_ctx,
-                  const DenseTensor& x,
-                  const Scalar& num_classes_s,
-                  DenseTensor* out) {
-  OneHotRawKernel<T, Context>(
-      dev_ctx, x, num_classes_s, phi::DataType::FLOAT32, false, out);
-}
-
-}  // namespace phi
-
-PD_REGISTER_KERNEL(one_hot, CPU, ALL_LAYOUT, phi::OneHotKernel, int, int64_t) {
-  kernel->OutputAt(0).SetDataType(phi::DataType::FLOAT32);
-}
-
-#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
-PD_REGISTER_KERNEL(one_hot, GPU, ALL_LAYOUT, phi::OneHotKernel, int, int64_t) {
-  kernel->OutputAt(0).SetDataType(phi::DataType::FLOAT32);
-}
-#endif
-
-#ifdef PADDLE_WITH_XPU
-PD_REGISTER_KERNEL(one_hot, XPU, ALL_LAYOUT, phi::OneHotKernel, int, int64_t) {
-  kernel->OutputAt(0).SetDataType(phi::DataType::FLOAT32);
-}
-#endif
diff --git a/paddle/phi/kernels/one_hot_kernel.h b/paddle/phi/kernels/one_hot_kernel.h
index 79af88473b2..b3b0bf3e9e4 100644
--- a/paddle/phi/kernels/one_hot_kernel.h
+++ b/paddle/phi/kernels/one_hot_kernel.h
@@ -25,12 +25,4 @@ void OneHotKernel(const Context& dev_ctx,
                   const Scalar& num_classes,
                   DenseTensor* out);
 
-template <typename T, typename Context>
-void OneHotRawKernel(const Context& dev_ctx,
-                     const DenseTensor& x,
-                     const Scalar& depth,
-                     DataType dtype,
-                     bool allow_out_of_range,
-                     DenseTensor* out);
-
 }  // namespace phi
diff --git a/paddle/phi/kernels/xpu/one_hot_kernel.cc b/paddle/phi/kernels/xpu/one_hot_kernel.cc
index c0b17b122ea..ad96d4858f7 100644
--- a/paddle/phi/kernels/xpu/one_hot_kernel.cc
+++ b/paddle/phi/kernels/xpu/one_hot_kernel.cc
@@ -19,49 +19,26 @@
 #include "paddle/phi/core/utils/data_type.h"
 
 namespace phi {
-template <typename Context, typename InT>
-struct OneHotV2OpFunctor {
-  const DenseTensor* in_;
-  DenseTensor* out_;
-  int depth_;
-  const Context& ctx_;
-
-  OneHotV2OpFunctor(const DenseTensor* in,
-                    DenseTensor* out,
-                    int depth,
-                    const Context& ctx)
-      : in_(in), out_(out), depth_(depth), ctx_(ctx) {}
-
-  template <typename OutT>
-  void apply() const {
-    auto* p_in_data = in_->data<InT>();
-    auto numel = in_->numel();
-    auto* p_out_data = ctx_.template Alloc<OutT>(out_);
-    int r = xpu::one_hot<InT>(
-        ctx_.x_context(), p_in_data, p_out_data, numel, depth_, 1.0, 0.0);
-    PADDLE_ENFORCE_XDNN_SUCCESS(r, "one_hot");
-  }
-};
-
 template <typename T, typename Context>
-void OneHotRawKernel(const Context& dev_ctx,
-                     const DenseTensor& x,
-                     const Scalar& depth,
-                     DataType dtype,
-                     bool allow_out_of_range,
-                     DenseTensor* out) {
+void OneHotKernel(const Context& dev_ctx,
+                  const DenseTensor& x,
+                  const Scalar& depth,
+                  DenseTensor* out) {
   auto depth_v = depth.to<int>();
   auto out_dims = out->dims();
   if (out_dims[out_dims.size() - 1] == -1) {
     out_dims[out_dims.size() - 1] = depth_v;
     out->Resize(out_dims);
   }
-  phi::VisitDataType(dtype,
-                     OneHotV2OpFunctor<Context, T>(&x, out, depth_v, dev_ctx));
+  auto* p_in_data = x.data<T>();
+  auto numel = x.numel();
+  auto* p_out_data = dev_ctx.template Alloc<float>(out);
+  int r = xpu::one_hot<T>(
+      dev_ctx.x_context(), p_in_data, p_out_data, numel, depth_v, 1.0, 0.0);
+  PADDLE_ENFORCE_XDNN_SUCCESS(r, "one_hot");
 }
 }  // namespace phi
 
-PD_REGISTER_KERNEL(
-    one_hot_raw, XPU, ALL_LAYOUT, phi::OneHotRawKernel, int, int64_t) {
-  kernel->OutputAt(0).SetDataType(phi::DataType::UNDEFINED);
+PD_REGISTER_KERNEL(one_hot, XPU, ALL_LAYOUT, phi::OneHotKernel, int, int64_t) {
+  kernel->OutputAt(0).SetDataType(phi::DataType::FLOAT32);
 }
-- 
GitLab