未验证 提交 7701db37 编写于 作者: H hong 提交者: GitHub

Move one hot to phi (#39876)

* move one hot to phi; test=develop

* fix bugs; test=develop

* fix bugs; test=develop

* add infer meta; test=develop

* fix bugs; test=develop

* resolve confilct

* resolve confilct

* fix bug;

* fix error; test=develop

* update; test=develop

* polish code; test=develop

* add one api in eager mode; test=develop

* add one hot test; test=develop

* remove use less code; test=develop

* fix bug; test=develop

* polish code; test=develop

* polish code; test=develop
上级 c9f3ad03
......@@ -500,8 +500,22 @@ phi::InferMetaContext BuildInferMetaContext(InferShapeContext* ctx,
"Unsupported attribute type is received when call "
"InferShapeFunctor."));
}
} else {
// do nothing
} else if (ctx->HasInput(attr_name)) {
// convert from data
if (attr_defs[i].type_index == std::type_index(typeid(int32_t))) {
if (ctx->IsRuntime()) {
const auto& infershape_inputs = ctx->GetInputVarPtrs(attr_name);
auto var_temp = BOOST_GET_CONST(Variable*, infershape_inputs[i]);
auto val = experimental::MakePhiScalarFromVar(*var_temp);
int32_t val_int = val.template to<int32_t>();
infer_meta_context.EmplaceBackAttr(val_int);
} else {
infer_meta_context.EmplaceBackAttr(-1);
}
} else {
PADDLE_THROW(platform::errors::Unimplemented(
"Get value from variable only support int yet"));
}
}
}
......
......@@ -2250,41 +2250,62 @@ void OperatorWithKernel::BuildPhiKernelContext(
}
} else {
// TODO(chenweihang): support other attrs later
auto& attr = Attrs().at(attr_names[i]);
auto attr_it = attrs_.find(attr_names[i]);
if (attr_defs[i].type_index == std::type_index(typeid(int))) {
pt_kernel_context->EmplaceBackAttr(BOOST_GET_CONST(int, attr));
if (attr_it == attrs_.end()) {
auto in_it = ctx.inputs.find(attr_names[i]);
if (in_it != ctx.inputs.end()) {
// get data from input
auto val = experimental::MakePhiScalarFromVar(*(in_it->second[0]));
int32_t val_int = val.template to<int32_t>();
pt_kernel_context->EmplaceBackAttr(val_int);
} else {
PADDLE_THROW(platform::errors::NotFound(
"can not find attribute `%s` both in attribute and input ",
attr_names[i]));
}
} else {
pt_kernel_context->EmplaceBackAttr(
BOOST_GET_CONST(int, attr_it->second));
}
} else if (attr_defs[i].type_index == std::type_index(typeid(float))) {
pt_kernel_context->EmplaceBackAttr(BOOST_GET_CONST(float, attr));
pt_kernel_context->EmplaceBackAttr(
BOOST_GET_CONST(float, attr_it->second));
} else if (attr_defs[i].type_index == std::type_index(typeid(bool))) {
pt_kernel_context->EmplaceBackAttr(BOOST_GET_CONST(bool, attr));
pt_kernel_context->EmplaceBackAttr(
BOOST_GET_CONST(bool, attr_it->second));
} else if (attr_defs[i].type_index == std::type_index(typeid(int64_t))) {
pt_kernel_context->EmplaceBackAttr(BOOST_GET_CONST(int64_t, attr));
pt_kernel_context->EmplaceBackAttr(
BOOST_GET_CONST(int64_t, attr_it->second));
} else if (attr_defs[i].type_index ==
std::type_index(typeid(std::string))) {
pt_kernel_context->EmplaceBackAttr(BOOST_GET_CONST(std::string, attr));
pt_kernel_context->EmplaceBackAttr(
BOOST_GET_CONST(std::string, attr_it->second));
} else if (attr_defs[i].type_index ==
std::type_index(typeid(phi::DataType))) {
auto data_type = paddle::framework::TransToPhiDataType(
static_cast<framework::proto::VarType::Type>(
BOOST_GET_CONST(int, attr)));
BOOST_GET_CONST(int, attr_it->second)));
pt_kernel_context->EmplaceBackAttr(data_type);
} else if (attr_defs[i].type_index ==
std::type_index(typeid(std::vector<int64_t>))) {
if (std::type_index(attr.type()) ==
if (std::type_index(attr_it->second.type()) ==
std::type_index(typeid(std::vector<int64_t>))) {
pt_kernel_context->EmplaceBackAttr(
BOOST_GET_CONST(std::vector<int64_t>, attr));
} else if (std::type_index(attr.type()) ==
BOOST_GET_CONST(std::vector<int64_t>, attr_it->second));
} else if (std::type_index(attr_it->second.type()) ==
std::type_index(typeid(std::vector<int>))) {
// Emplace Back Attr according to the type of Phi_Kernel args.
const auto& vector_int_attr = BOOST_GET_CONST(std::vector<int>, attr);
const auto& vector_int_attr =
BOOST_GET_CONST(std::vector<int>, attr_it->second);
const std::vector<int64_t> vector_int64_attr(vector_int_attr.begin(),
vector_int_attr.end());
pt_kernel_context->EmplaceBackAttr(vector_int64_attr);
}
} else if (attr_defs[i].type_index ==
std::type_index(typeid(std::vector<int32_t>))) {
const auto& vector_int_attr = BOOST_GET_CONST(std::vector<int>, attr);
const auto& vector_int_attr =
BOOST_GET_CONST(std::vector<int>, attr_it->second);
pt_kernel_context->EmplaceBackAttr(vector_int_attr);
} else {
PADDLE_THROW(platform::errors::Unimplemented(
......
......@@ -419,6 +419,17 @@ void BuildDygraphPhiKernelContext(
experimental::MakePhiScalarFromVar(ins_vector[0]->Var())));
}
} else if (ins.find(attr_names[i]) != ins.end()) {
// deal tensor attr here
auto& ins_vector = ins.at(attr_names[i]);
auto tensor_attr =
experimental::MakePhiScalarFromVar(ins_vector[0]->Var());
if (attr_defs[i].type_index == std::type_index(typeid(int))) {
int val = tensor_attr.template to<int>();
kernel_ctx->EmplaceBackAttr(val);
} else {
PADDLE_THROW(platform::errors::Unimplemented("only support int here"));
}
} else if (attr_defs[i].type_index ==
std::type_index(typeid(std::vector<phi::Scalar>))) {
auto& attr = GetAttr(attrs, default_attrs, attr_names[i]);
......@@ -475,6 +486,7 @@ void BuildDygraphPhiKernelContext(
}
} else {
// TODO(chenweihang): support other attrs later
auto& attr = GetAttr(attrs, default_attrs, attr_names[i]);
if (attr_defs[i].type_index == std::type_index(typeid(int))) {
kernel_ctx->EmplaceBackAttr(BOOST_GET_CONST(int, attr));
......
......@@ -12,9 +12,13 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/one_hot_v2_op.h"
#include <string>
#include <vector>
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/infermeta/backward.h"
#include "paddle/phi/infermeta/unary.h"
namespace paddle {
namespace operators {
......@@ -22,26 +26,6 @@ namespace operators {
class OneHotV2Op : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "one_hot_v2");
OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "one_hot_v2");
auto x_dims = ctx->GetInputDim("X");
PADDLE_ENFORCE_GE(x_dims.size(), 1,
platform::errors::InvalidArgument(
"Rank of Input(X) should be at least 1."));
int depth = ctx->Attrs().Get<int>("depth");
if (ctx->HasInput("depth_tensor")) {
depth = -1;
}
auto out_dims_vec = phi::vectorize(x_dims);
out_dims_vec.push_back(depth);
auto out_dims = phi::make_ddim(out_dims_vec);
ctx->SetOutputDim("Out", out_dims);
ctx->ShareLoD("X", /* --> */ "Out");
}
protected:
framework::OpKernelType GetExpectedKernelType(
......@@ -52,7 +36,7 @@ class OneHotV2Op : public framework::OperatorWithKernel {
}
framework::OpKernelType GetKernelTypeForVar(
const std::string& var_name, const Tensor& tensor,
const std::string& var_name, const framework::Tensor& tensor,
const framework::OpKernelType& expected_kernel_type) const override {
if (var_name == "depth_tensor") {
return expected_kernel_type;
......@@ -114,10 +98,12 @@ Out is a LoDTensor:
} // namespace paddle
namespace ops = paddle::operators;
DECLARE_INFER_SHAPE_FUNCTOR(one_hot_v2, OneHotInferShapeFunctor,
PD_INFER_META(phi::OneHotRawInferMeta));
REGISTER_OPERATOR(
one_hot_v2, ops::OneHotV2Op, ops::OneHotV2OpMaker,
paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);
REGISTER_OP_CPU_KERNEL(
one_hot_v2, ops::OneHotV2Kernel<paddle::platform::CPUDeviceContext, int>,
ops::OneHotV2Kernel<paddle::platform::CPUDeviceContext, int64_t>);
paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>,
OneHotInferShapeFunctor);
......@@ -12,13 +12,14 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/one_hot_v2_op.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/platform/device/npu/npu_op_runner.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
using LoDTensor = framework::LoDTensor;
template <typename T>
class OneHotV2NPUKernel : public framework::OpKernel<T> {
......
......@@ -55,6 +55,7 @@ const std::unordered_set<std::string> deprecated_op_names({"diag",
"expand_grad",
"expand_as_grad",
"sum",
"one_hot",
"sum_grad",
"top_k",
"top_k_grad"});
......
......@@ -72,6 +72,10 @@ void MetaTensor::set_layout(DataLayout layout) {
}
void MetaTensor::share_lod(const MetaTensor& meta_tensor) {
if (meta_tensor.lod().size() == 0) {
// no need share
return;
}
if (phi::DenseTensor::classof(tensor_)) {
DenseTensorUtils::GetMutableMeta(static_cast<DenseTensor*>(tensor_))->lod =
meta_tensor.lod();
......
......@@ -1602,6 +1602,43 @@ void UnfoldInferMeta(const MetaTensor& x,
out->set_dims(phi::make_ddim(out_dims));
}
void OneHotRawInferMeta(const MetaTensor& x,
int32_t depth,
DataType dtype,
bool allow_out_of_range,
MetaTensor* out) {
auto x_dims = x.dims();
PADDLE_ENFORCE_GE(
x_dims.size(),
1,
phi::errors::InvalidArgument("Rank of Input(X) should be at least 1."));
auto out_dims_vec = phi::vectorize(x_dims);
out_dims_vec.push_back(depth);
auto out_dims = phi::make_ddim(out_dims_vec);
out->set_dims(out_dims);
out->share_lod(x);
out->set_dtype(dtype);
}
void OneHotInferMeta(const MetaTensor& x,
const Scalar& depth_t,
MetaTensor* out) {
auto x_dims = x.dims();
PADDLE_ENFORCE_GE(
x_dims.size(),
1,
phi::errors::InvalidArgument("Rank of Input(X) should be at least 1."));
int depth = depth_t.to<int>();
auto out_dims_vec = phi::vectorize(x_dims);
out_dims_vec.push_back(depth);
auto out_dims = phi::make_ddim(out_dims_vec);
out->set_dims(out_dims);
out->share_lod(x);
out->set_dtype(phi::DataType::FLOAT32);
}
void WhereIndexInferMeta(const MetaTensor& condition, MetaTensor* out) {
auto rank = condition.dims().size();
PADDLE_ENFORCE_GE(
......
......@@ -228,6 +228,14 @@ void UnfoldInferMeta(const MetaTensor& x,
MetaTensor* out,
MetaConfig config = MetaConfig());
void OneHotRawInferMeta(const MetaTensor& x,
int32_t depth,
DataType dtype,
bool allow_out_of_range,
MetaTensor* out);
void OneHotInferMeta(const MetaTensor& x, const Scalar& depth, MetaTensor* out);
void WhereIndexInferMeta(const MetaTensor& condition, MetaTensor* out);
} // namespace phi
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
......@@ -12,23 +12,25 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/phi/kernels/one_hot_kernel.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/utils/data_type.h"
#include "paddle/phi/kernels/funcs/math_function.h"
namespace paddle {
namespace operators {
namespace phi {
template <typename DeviceContext, typename InT>
struct OneHotV2OpFunctor {
const framework::LoDTensor* in_;
framework::LoDTensor* out_;
const DenseTensor* in_;
DenseTensor* out_;
int depth_;
const DeviceContext& ctx_;
bool allow_out_of_range_;
OneHotV2OpFunctor(const framework::LoDTensor* in, framework::LoDTensor* out,
int depth, const DeviceContext& ctx,
OneHotV2OpFunctor(const DenseTensor* in,
DenseTensor* out,
int depth,
const DeviceContext& ctx,
bool allow_out_of_range = false)
: in_(in),
out_(out),
......@@ -40,8 +42,8 @@ struct OneHotV2OpFunctor {
void apply() const {
auto* p_in_data = in_->data<InT>();
auto numel = in_->numel();
auto* p_out_data = out_->mutable_data<OutT>(ctx_.GetPlace());
phi::funcs::set_constant(ctx_, out_, 0.0);
auto* p_out_data = ctx_.template Alloc<OutT>(out_);
funcs::set_constant(ctx_, out_, 0.0);
if (allow_out_of_range_) {
for (int i = 0; i < numel; ++i) {
......@@ -52,51 +54,46 @@ struct OneHotV2OpFunctor {
} else {
for (int i = 0; i < numel; ++i) {
PADDLE_ENFORCE_GE(
p_in_data[i], 0,
platform::errors::InvalidArgument(
p_in_data[i],
0,
phi::errors::InvalidArgument(
"Illegal index value, Input(input) value should be at least 0, "
"but received input (%d) less than 0",
p_in_data[i]));
PADDLE_ENFORCE_LT(
p_in_data[i], depth_,
platform::errors::InvalidArgument(
p_in_data[i],
depth_,
phi::errors::InvalidArgument(
"Illegal index value, Input(input) value should be less than "
"Input(depth), "
"but received input (%d) not less than depth (%d)",
p_in_data[i], depth_));
p_in_data[i],
depth_));
*(p_out_data + i * depth_ + p_in_data[i]) = 1.0;
}
}
}
};
using LoDTensor = framework::LoDTensor;
using Tensor = framework::Tensor;
template <typename DeviceContext, typename T>
class OneHotV2Kernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* in = context.Input<LoDTensor>("X");
auto* out = context.Output<LoDTensor>("Out");
int depth = context.Attr<int>("depth");
bool allow_out_of_range = context.Attr<bool>("allow_out_of_range");
if (context.HasInput("depth_tensor")) {
auto* depth_tensor = context.Input<Tensor>("depth_tensor");
auto* depth_data = depth_tensor->data<int32_t>();
depth = depth_data[0];
auto out_dims = out->dims();
out_dims[out_dims.size() - 1] = depth;
out->Resize(out_dims);
}
framework::VisitDataType(
static_cast<framework::proto::VarType::Type>(
context.Attr<int>("dtype")),
OneHotV2OpFunctor<DeviceContext, T>(
in, out, depth, context.template device_context<DeviceContext>(),
allow_out_of_range));
template <typename T, typename Context>
void OneHotRawKernel(const Context& dev_ctx,
const DenseTensor& x,
int32_t depth,
DataType dtype,
bool allow_out_of_range,
DenseTensor* out) {
auto out_dims = out->dims();
if (out_dims[out_dims.size() - 1] == -1) {
out_dims[out_dims.size() - 1] = depth;
out->Resize(out_dims);
}
};
} // namespace operators
} // namespace paddle
phi::VisitDataType(dtype,
OneHotV2OpFunctor<Context, T>(
&x, out, depth, dev_ctx, allow_out_of_range));
}
} // namespace phi
PD_REGISTER_KERNEL(
one_hot_raw, CPU, ALL_LAYOUT, phi::OneHotRawKernel, int, int64_t) {}
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
......@@ -12,17 +12,22 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/one_hot_v2_op.h"
#include "paddle/phi/kernels/one_hot_kernel.h"
#include "paddle/fluid/platform/device/gpu/gpu_info.h"
#include "paddle/fluid/platform/device/gpu/gpu_primitives.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/math_function.h"
namespace phi {
namespace paddle {
namespace operators {
using platform::PADDLE_CUDA_NUM_THREADS;
using paddle::platform::PADDLE_CUDA_NUM_THREADS;
template <typename InT, typename OutT>
__global__ void FillOutputKernel(const InT* p_in_data, OutT* p_out_data,
const int64_t numel, const int depth) {
__global__ void FillOutputKernel(const InT* p_in_data,
OutT* p_out_data,
const int64_t numel,
const int depth) {
int idx = blockIdx.x * blockDim.x + threadIdx.x;
if (idx < numel && p_in_data[idx] >= 0 && p_in_data[idx] < depth) {
*(p_out_data + (idx * depth) + p_in_data[idx]) = 1.0;
......@@ -31,13 +36,14 @@ __global__ void FillOutputKernel(const InT* p_in_data, OutT* p_out_data,
template <typename DeviceContext, typename InT>
struct OneHotV2OpCUDAFunctor {
const framework::LoDTensor* in_;
framework::LoDTensor* out_;
const DenseTensor* in_;
DenseTensor* out_;
const DeviceContext& ctx_;
int depth_;
OneHotV2OpCUDAFunctor(const framework::LoDTensor* in,
framework::LoDTensor* out, int depth,
OneHotV2OpCUDAFunctor(const DenseTensor* in,
DenseTensor* out,
int depth,
const DeviceContext& ctx)
: in_(in), out_(out), depth_(depth), ctx_(ctx) {}
......@@ -45,56 +51,36 @@ struct OneHotV2OpCUDAFunctor {
void apply() const {
auto* p_in_data = in_->data<InT>();
auto numel = in_->numel();
auto* p_out_data = out_->mutable_data<OutT>(ctx_.GetPlace());
auto* p_out_data = ctx_.template Alloc<OutT>(out_);
auto stream = ctx_.stream();
phi::funcs::set_constant(ctx_, out_, 0.0);
funcs::set_constant(ctx_, out_, 0.0);
FillOutputKernel<<<(numel + PADDLE_CUDA_NUM_THREADS - 1) /
PADDLE_CUDA_NUM_THREADS,
PADDLE_CUDA_NUM_THREADS, 0, stream>>>(
p_in_data, p_out_data, numel, depth_);
PADDLE_CUDA_NUM_THREADS,
0,
stream>>>(p_in_data, p_out_data, numel, depth_);
}
};
using LoDTensor = framework::LoDTensor;
template <typename DeviceContext, typename T>
class OneHotV2CUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* in = context.Input<LoDTensor>("X");
auto* out = context.Output<LoDTensor>("Out");
int depth = -1;
if (context.HasInput("depth_tensor")) {
auto* depth_tensor = context.Input<framework::Tensor>("depth_tensor");
if (platform::is_gpu_place(depth_tensor->place())) {
framework::Tensor temp;
paddle::framework::TensorCopySync(*depth_tensor, platform::CPUPlace(),
&temp);
depth = *temp.data<int32_t>();
} else {
depth = *depth_tensor->data<int32_t>();
}
auto out_dims = out->dims();
out_dims[out_dims.size() - 1] = depth;
out->Resize(out_dims);
} else {
depth = context.Attr<int>("depth");
}
framework::VisitDataType(
static_cast<framework::proto::VarType::Type>(
context.Attr<int>("dtype")),
OneHotV2OpCUDAFunctor<DeviceContext, T>(
in, out, depth, context.template device_context<DeviceContext>()));
template <typename T, typename Context>
void OneHotRawKernel(const Context& dev_ctx,
const DenseTensor& x,
int32_t depth,
DataType dtype,
bool allow_out_of_range,
DenseTensor* out) {
auto out_dims = out->dims();
if (out_dims[out_dims.size() - 1] == -1) {
out_dims[out_dims.size() - 1] = depth;
out->Resize(out_dims);
}
};
} // namespace operators
} // namespace paddle
phi::VisitDataType(
dtype, OneHotV2OpCUDAFunctor<Context, T>(&x, out, depth, dev_ctx));
}
} // namespace phi
namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(
one_hot_v2,
ops::OneHotV2CUDAKernel<paddle::platform::CUDADeviceContext, int>,
ops::OneHotV2CUDAKernel<paddle::platform::CUDADeviceContext, int64_t>);
PD_REGISTER_KERNEL(
one_hot_raw, GPU, ALL_LAYOUT, phi::OneHotRawKernel, int, int64_t) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/one_hot_kernel.h"
#include "paddle/phi/backends/all_context.h"
#include "paddle/phi/core/kernel_registry.h"
namespace phi {
template <typename T, typename Context>
void OneHotKernel(const Context& dev_ctx,
const DenseTensor& x,
const Scalar& num_classes_s,
DenseTensor* out) {
int num_classes = num_classes_s.to<int>();
OneHotRawKernel<T>(
dev_ctx, x, num_classes, phi::DataType::FLOAT32, false, out);
}
} // namespace phi
PD_REGISTER_KERNEL(one_hot, CPU, ALL_LAYOUT, phi::OneHotKernel, int, int64_t) {}
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
PD_REGISTER_KERNEL(one_hot, GPU, ALL_LAYOUT, phi::OneHotKernel, int, int64_t) {}
#endif
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/common/scalar.h"
#include "paddle/phi/core/dense_tensor.h"
namespace phi {
template <typename T, typename Context>
void OneHotKernel(const Context& dev_ctx,
const DenseTensor& x,
const Scalar& num_classes,
DenseTensor* out);
template <typename T, typename Context>
void OneHotRawKernel(const Context& dev_ctx,
const DenseTensor& x,
int32_t depth,
DataType dtype,
bool allow_out_of_range,
DenseTensor* out);
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/core/compat/op_utils.h"
namespace phi {
KernelSignature OneHotOpArgumentMapping(const ArgumentMappingContext& ctx) {
if (ctx.HasInput("depth_tensor")) {
return KernelSignature("one_hot_raw",
{"X"},
{"depth_tensor", "dtype", "allow_out_of_range"},
{"Out"});
} else {
return KernelSignature("one_hot_raw",
{"X"},
{"depth", "dtype", "allow_out_of_range"},
{"Out"});
}
}
} // namespace phi
PD_REGISTER_BASE_KERNEL_NAME(one_hot_v2, one_hot);
PD_REGISTER_ARG_MAPPING_FN(one_hot_v2, phi::OneHotOpArgumentMapping);
......@@ -52,6 +52,12 @@ final_state_name_mapping = {
"axis1": "axis1",
"axis2": "axis2",
"out": "Out",
},
"one_hot": {
"final_op_name": "final_state_one_hot",
"x": "X",
"num_class": "depth",
"out": "Out",
}
}
......
......@@ -22,7 +22,8 @@ import paddle
import paddle.fluid as fluid
import paddle.fluid.core as core
import paddle.fluid.framework as framework
from paddle.fluid.framework import Program, program_guard
from paddle.framework import _in_eager_mode
from paddle.fluid.framework import Program, program_guard, _test_eager_guard
class TestOneHotOp(OpTest):
......@@ -45,7 +46,7 @@ class TestOneHotOp(OpTest):
self.outputs = {'Out': (out, x_lod)}
def test_check_output(self):
self.check_output(check_dygraph=False)
self.check_output()
class TestOneHotOp_attr(OpTest):
......@@ -68,7 +69,7 @@ class TestOneHotOp_attr(OpTest):
self.outputs = {'Out': (out, x_lod)}
def test_check_output(self):
self.check_output(check_dygraph=False)
self.check_output()
class TestOneHotOp_default_dtype(OpTest):
......@@ -91,7 +92,7 @@ class TestOneHotOp_default_dtype(OpTest):
self.outputs = {'Out': (out, x_lod)}
def test_check_output(self):
self.check_output(check_dygraph=False)
self.check_output()
class TestOneHotOp_default_dtype_attr(OpTest):
......@@ -114,7 +115,7 @@ class TestOneHotOp_default_dtype_attr(OpTest):
self.outputs = {'Out': (out, x_lod)}
def test_check_output(self):
self.check_output(check_dygraph=False)
self.check_output()
class TestOneHotOp_out_of_range(OpTest):
......@@ -132,7 +133,7 @@ class TestOneHotOp_out_of_range(OpTest):
self.outputs = {'Out': (out, x_lod)}
def test_check_output(self):
self.check_output(check_dygraph=False)
self.check_output()
class TestOneHotOp_exception(unittest.TestCase):
......@@ -190,6 +191,12 @@ class TestOneHotOpApi(unittest.TestCase):
one_hot_label = fluid.one_hot(
input=fluid.dygraph.to_variable(label), depth=depth)
one_hot_label = paddle.nn.functional.one_hot(
fluid.dygraph.to_variable(label), depth)
with _test_eager_guard():
one_hot_label = paddle.nn.functional.one_hot(
paddle.to_tensor(label), depth)
def _run(self, depth):
label = fluid.layers.data(name="label", shape=[1], dtype="int64")
one_hot_label = fluid.one_hot(input=label, depth=depth)
......
......@@ -19,6 +19,7 @@ from ...fluid.layer_helper import LayerHelper
from ...fluid.data_feeder import check_variable_and_dtype, check_dtype
from paddle import _C_ops
from paddle import in_dynamic_mode
from paddle.framework import _in_eager_mode
__all__ = []
......@@ -87,6 +88,8 @@ def one_hot(x, num_classes, name=None):
"""
if in_dynamic_mode():
if _in_eager_mode():
return _C_ops.final_state_one_hot(x, num_classes)
return _C_ops.one_hot_v2(x, 'depth', num_classes, 'allow_out_of_range',
False)
else:
......
......@@ -204,6 +204,15 @@
output : Tensor
invoke : full_like(x, 0, dtype, place)
- api : one_hot
args : (Tensor x, Scalar num_classes)
output : Tensor
infer_meta :
func : OneHotInferMeta
kernel :
func : one_hot
- api : digamma
args : (Tensor x)
output : Tensor
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册