未验证 提交 34122e3e 编写于 作者: Z zhangyuqin1998 提交者: GitHub

move OneHotRawKernel to legacy (#53200)

* move OneHotRawKernel to legacy

* fix
上级 3e90a461
......@@ -63,12 +63,10 @@ struct OneHotV2OpFunctor {
};
template <typename T, typename Context>
void OneHotRawKernel(const Context& dev_ctx,
const DenseTensor& x,
const Scalar& depth,
DataType dtype,
bool allow_out_of_range,
DenseTensor* out) {
void OneHotKernel(const Context& dev_ctx,
const DenseTensor& x,
const Scalar& depth,
DenseTensor* out) {
auto depth_v = depth.to<int>();
auto out_dims = out->dims();
if (out_dims[out_dims.size() - 1] == -1) {
......@@ -76,13 +74,34 @@ void OneHotRawKernel(const Context& dev_ctx,
out->Resize(out_dims);
}
phi::VisitDataType(dtype,
OneHotV2OpFunctor<Context, T>(&x, out, depth_v, dev_ctx));
auto* p_in_data = x.data<T>();
auto numel = x.numel();
auto* p_out_data = dev_ctx.template Alloc<float>(out);
funcs::set_constant(dev_ctx, out, 0.0);
for (int i = 0; i < numel; ++i) {
PADDLE_ENFORCE_GE(
p_in_data[i],
0,
phi::errors::InvalidArgument(
"Illegal index value, Input(input) value should be at least 0, "
"but received input (%d) less than 0",
p_in_data[i]));
PADDLE_ENFORCE_LT(
p_in_data[i],
depth_v,
phi::errors::InvalidArgument(
"Illegal index value, Input(input) value should be less than "
"Input(depth), "
"but received input (%d) not less than depth (%d)",
p_in_data[i],
depth_v));
*(p_out_data + i * depth_v + p_in_data[i]) = 1.0;
}
}
} // namespace phi
PD_REGISTER_KERNEL(
one_hot_raw, CPU, ALL_LAYOUT, phi::OneHotRawKernel, int, int64_t) {
kernel->OutputAt(0).SetDataType(phi::DataType::UNDEFINED);
PD_REGISTER_KERNEL(one_hot, CPU, ALL_LAYOUT, phi::OneHotKernel, int, int64_t) {
kernel->OutputAt(0).SetDataType(phi::DataType::FLOAT32);
}
......@@ -40,43 +40,11 @@ __global__ void FillOutputKernel(const InT* p_in_data,
}
}
template <typename DeviceContext, typename InT>
struct OneHotV2OpCUDAFunctor {
const DenseTensor* in_;
DenseTensor* out_;
const DeviceContext& ctx_;
int depth_;
OneHotV2OpCUDAFunctor(const DenseTensor* in,
DenseTensor* out,
int depth,
const DeviceContext& ctx)
: in_(in), out_(out), depth_(depth), ctx_(ctx) {}
template <typename OutT>
void apply() const {
auto* p_in_data = in_->data<InT>();
auto numel = in_->numel();
auto* p_out_data = ctx_.template Alloc<OutT>(out_);
auto stream = ctx_.stream();
funcs::set_constant(ctx_, out_, 0.0);
auto config = phi::backends::gpu::GetGpuLaunchConfig1D(ctx_, numel);
FillOutputKernel<<<config.block_per_grid,
config.thread_per_block,
0,
stream>>>(p_in_data, p_out_data, numel, depth_);
}
};
template <typename T, typename Context>
void OneHotRawKernel(const Context& dev_ctx,
const DenseTensor& x,
const Scalar& depth,
DataType dtype,
bool allow_out_of_range,
DenseTensor* out) {
void OneHotKernel(const Context& dev_ctx,
const DenseTensor& x,
const Scalar& depth,
DenseTensor* out) {
auto depth_v = depth.to<int>();
auto out_dims = out->dims();
if (out_dims[out_dims.size() - 1] == -1) {
......@@ -84,13 +52,22 @@ void OneHotRawKernel(const Context& dev_ctx,
out->Resize(out_dims);
}
phi::VisitDataType(
dtype, OneHotV2OpCUDAFunctor<Context, T>(&x, out, depth_v, dev_ctx));
auto* p_in_data = x.data<T>();
auto numel = x.numel();
auto* p_out_data = dev_ctx.template Alloc<float>(out);
auto stream = dev_ctx.stream();
funcs::set_constant(dev_ctx, out, 0.0);
auto config = phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, numel);
FillOutputKernel<<<config.block_per_grid,
config.thread_per_block,
0,
stream>>>(p_in_data, p_out_data, numel, depth_v);
}
} // namespace phi
PD_REGISTER_KERNEL(
one_hot_raw, GPU, ALL_LAYOUT, phi::OneHotRawKernel, int, int64_t) {
kernel->OutputAt(0).SetDataType(phi::DataType::UNDEFINED);
PD_REGISTER_KERNEL(one_hot, GPU, ALL_LAYOUT, phi::OneHotKernel, int, int64_t) {
kernel->OutputAt(0).SetDataType(phi::DataType::FLOAT32);
}
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/common/scalar.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/utils/data_type.h"
#include "paddle/phi/kernels/funcs/math_function.h"
namespace phi {
template <typename DeviceContext, typename InT>
struct OneHotV2OpFunctor {
const DenseTensor* in_;
DenseTensor* out_;
int depth_;
const DeviceContext& ctx_;
OneHotV2OpFunctor(const DenseTensor* in,
DenseTensor* out,
int depth,
const DeviceContext& ctx)
: in_(in), out_(out), depth_(depth), ctx_(ctx) {}
template <typename OutT>
void apply() const {
auto* p_in_data = in_->data<InT>();
auto numel = in_->numel();
auto* p_out_data = ctx_.template Alloc<OutT>(out_);
funcs::set_constant(ctx_, out_, 0.0);
for (int i = 0; i < numel; ++i) {
PADDLE_ENFORCE_GE(
p_in_data[i],
0,
phi::errors::InvalidArgument(
"Illegal index value, Input(input) value should be at least 0, "
"but received input (%d) less than 0",
p_in_data[i]));
PADDLE_ENFORCE_LT(
p_in_data[i],
depth_,
phi::errors::InvalidArgument(
"Illegal index value, Input(input) value should be less than "
"Input(depth), "
"but received input (%d) not less than depth (%d)",
p_in_data[i],
depth_));
*(p_out_data + i * depth_ + p_in_data[i]) = 1.0;
}
}
};
template <typename T, typename Context>
void OneHotRawKernel(const Context& dev_ctx,
const DenseTensor& x,
const Scalar& depth,
DataType dtype,
bool allow_out_of_range,
DenseTensor* out) {
auto depth_v = depth.to<int>();
auto out_dims = out->dims();
if (out_dims[out_dims.size() - 1] == -1) {
out_dims[out_dims.size() - 1] = depth_v;
out->Resize(out_dims);
}
phi::VisitDataType(dtype,
OneHotV2OpFunctor<Context, T>(&x, out, depth_v, dev_ctx));
}
} // namespace phi
PD_REGISTER_KERNEL(
one_hot_raw, CPU, ALL_LAYOUT, phi::OneHotRawKernel, int, int64_t) {
kernel->OutputAt(0).SetDataType(phi::DataType::UNDEFINED);
}
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/backends/gpu/gpu_launch_config.h"
#include "paddle/phi/backends/gpu/gpu_primitives.h"
#include "paddle/phi/common/scalar.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/math_function.h"
namespace phi {
using phi::PADDLE_CUDA_NUM_THREADS;
template <typename InT, typename OutT>
__global__ void FillOutputKernel(const InT* p_in_data,
OutT* p_out_data,
const int64_t numel,
const int depth) {
CUDA_KERNEL_LOOP_TYPE(idx, numel, int64_t) {
PADDLE_ENFORCE(p_in_data[idx] >= 0 && p_in_data[idx] < depth,
"Illegal index value, Input(input) value should be "
"greater than or equal to 0, and less than depth [%d], "
"but received [%lld].",
depth,
p_in_data[idx]);
*(p_out_data + (idx * depth) + p_in_data[idx]) = 1.0;
}
}
template <typename DeviceContext, typename InT>
struct OneHotV2OpCUDAFunctor {
const DenseTensor* in_;
DenseTensor* out_;
const DeviceContext& ctx_;
int depth_;
OneHotV2OpCUDAFunctor(const DenseTensor* in,
DenseTensor* out,
int depth,
const DeviceContext& ctx)
: in_(in), out_(out), depth_(depth), ctx_(ctx) {}
template <typename OutT>
void apply() const {
auto* p_in_data = in_->data<InT>();
auto numel = in_->numel();
auto* p_out_data = ctx_.template Alloc<OutT>(out_);
auto stream = ctx_.stream();
funcs::set_constant(ctx_, out_, 0.0);
auto config = phi::backends::gpu::GetGpuLaunchConfig1D(ctx_, numel);
FillOutputKernel<<<config.block_per_grid,
config.thread_per_block,
0,
stream>>>(p_in_data, p_out_data, numel, depth_);
}
};
template <typename T, typename Context>
void OneHotRawKernel(const Context& dev_ctx,
const DenseTensor& x,
const Scalar& depth,
DataType dtype,
bool allow_out_of_range,
DenseTensor* out) {
auto depth_v = depth.to<int>();
auto out_dims = out->dims();
if (out_dims[out_dims.size() - 1] == -1) {
out_dims[out_dims.size() - 1] = depth_v;
out->Resize(out_dims);
}
phi::VisitDataType(
dtype, OneHotV2OpCUDAFunctor<Context, T>(&x, out, depth_v, dev_ctx));
}
} // namespace phi
PD_REGISTER_KERNEL(
one_hot_raw, GPU, ALL_LAYOUT, phi::OneHotRawKernel, int, int64_t) {
kernel->OutputAt(0).SetDataType(phi::DataType::UNDEFINED);
}
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/backends/xpu/enforce_xpu.h"
#include "paddle/phi/backends/xpu/xpu_context.h"
#include "paddle/phi/common/scalar.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/utils/data_type.h"
namespace phi {
template <typename Context, typename InT>
struct OneHotV2OpFunctor {
const DenseTensor* in_;
DenseTensor* out_;
int depth_;
const Context& ctx_;
OneHotV2OpFunctor(const DenseTensor* in,
DenseTensor* out,
int depth,
const Context& ctx)
: in_(in), out_(out), depth_(depth), ctx_(ctx) {}
template <typename OutT>
void apply() const {
auto* p_in_data = in_->data<InT>();
auto numel = in_->numel();
auto* p_out_data = ctx_.template Alloc<float>(out_);
int r = xpu::one_hot<InT>(
ctx_.x_context(), p_in_data, p_out_data, numel, depth_, 1.0, 0.0);
PADDLE_ENFORCE_XDNN_SUCCESS(r, "one_hot");
}
};
template <typename T, typename Context>
void OneHotRawKernel(const Context& dev_ctx,
const DenseTensor& x,
const Scalar& depth,
DataType dtype,
bool allow_out_of_range,
DenseTensor* out) {
auto depth_v = depth.to<int>();
auto out_dims = out->dims();
if (out_dims[out_dims.size() - 1] == -1) {
out_dims[out_dims.size() - 1] = depth_v;
out->Resize(out_dims);
}
phi::VisitDataType(dtype,
OneHotV2OpFunctor<Context, T>(&x, out, depth_v, dev_ctx));
}
} // namespace phi
PD_REGISTER_KERNEL(
one_hot_raw, XPU, ALL_LAYOUT, phi::OneHotRawKernel, int, int64_t) {
kernel->OutputAt(0).SetDataType(phi::DataType::UNDEFINED);
}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/one_hot_kernel.h"
#include "paddle/phi/backends/all_context.h"
#include "paddle/phi/core/kernel_registry.h"
namespace phi {
template <typename T, typename Context>
void OneHotKernel(const Context& dev_ctx,
const DenseTensor& x,
const Scalar& num_classes_s,
DenseTensor* out) {
OneHotRawKernel<T>(
dev_ctx, x, num_classes_s, phi::DataType::FLOAT32, false, out);
}
} // namespace phi
PD_REGISTER_KERNEL(one_hot, CPU, ALL_LAYOUT, phi::OneHotKernel, int, int64_t) {
kernel->OutputAt(0).SetDataType(phi::DataType::FLOAT32);
}
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
PD_REGISTER_KERNEL(one_hot, GPU, ALL_LAYOUT, phi::OneHotKernel, int, int64_t) {
kernel->OutputAt(0).SetDataType(phi::DataType::FLOAT32);
}
#endif
#ifdef PADDLE_WITH_XPU
PD_REGISTER_KERNEL(one_hot, XPU, ALL_LAYOUT, phi::OneHotKernel, int, int64_t) {
kernel->OutputAt(0).SetDataType(phi::DataType::FLOAT32);
}
#endif
......@@ -25,12 +25,4 @@ void OneHotKernel(const Context& dev_ctx,
const Scalar& num_classes,
DenseTensor* out);
template <typename T, typename Context>
void OneHotRawKernel(const Context& dev_ctx,
const DenseTensor& x,
const Scalar& depth,
DataType dtype,
bool allow_out_of_range,
DenseTensor* out);
} // namespace phi
......@@ -19,49 +19,26 @@
#include "paddle/phi/core/utils/data_type.h"
namespace phi {
template <typename Context, typename InT>
struct OneHotV2OpFunctor {
const DenseTensor* in_;
DenseTensor* out_;
int depth_;
const Context& ctx_;
OneHotV2OpFunctor(const DenseTensor* in,
DenseTensor* out,
int depth,
const Context& ctx)
: in_(in), out_(out), depth_(depth), ctx_(ctx) {}
template <typename OutT>
void apply() const {
auto* p_in_data = in_->data<InT>();
auto numel = in_->numel();
auto* p_out_data = ctx_.template Alloc<float>(out_);
int r = xpu::one_hot<InT>(
ctx_.x_context(), p_in_data, p_out_data, numel, depth_, 1.0, 0.0);
PADDLE_ENFORCE_XDNN_SUCCESS(r, "one_hot");
}
};
template <typename T, typename Context>
void OneHotRawKernel(const Context& dev_ctx,
const DenseTensor& x,
const Scalar& depth,
DataType dtype,
bool allow_out_of_range,
DenseTensor* out) {
void OneHotKernel(const Context& dev_ctx,
const DenseTensor& x,
const Scalar& depth,
DenseTensor* out) {
auto depth_v = depth.to<int>();
auto out_dims = out->dims();
if (out_dims[out_dims.size() - 1] == -1) {
out_dims[out_dims.size() - 1] = depth_v;
out->Resize(out_dims);
}
phi::VisitDataType(dtype,
OneHotV2OpFunctor<Context, T>(&x, out, depth_v, dev_ctx));
auto* p_in_data = x.data<T>();
auto numel = x.numel();
auto* p_out_data = dev_ctx.template Alloc<float>(out);
int r = xpu::one_hot<T>(
dev_ctx.x_context(), p_in_data, p_out_data, numel, depth_v, 1.0, 0.0);
PADDLE_ENFORCE_XDNN_SUCCESS(r, "one_hot");
}
} // namespace phi
PD_REGISTER_KERNEL(
one_hot_raw, XPU, ALL_LAYOUT, phi::OneHotRawKernel, int, int64_t) {
kernel->OutputAt(0).SetDataType(phi::DataType::UNDEFINED);
PD_REGISTER_KERNEL(one_hot, XPU, ALL_LAYOUT, phi::OneHotKernel, int, int64_t) {
kernel->OutputAt(0).SetDataType(phi::DataType::FLOAT32);
}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册