未验证 提交 7701db37 编写于 作者: H hong 提交者: GitHub

Move one hot to phi (#39876)

* move one hot to phi; test=develop

* fix bugs; test=develop

* fix bugs; test=develop

* add infer meta; test=develop

* fix bugs; test=develop

* resolve confilct

* resolve confilct

* fix bug;

* fix error; test=develop

* update; test=develop

* polish code; test=develop

* add one api in eager mode; test=develop

* add one hot test; test=develop

* remove use less code; test=develop

* fix bug; test=develop

* polish code; test=develop

* polish code; test=develop
上级 c9f3ad03
...@@ -500,8 +500,22 @@ phi::InferMetaContext BuildInferMetaContext(InferShapeContext* ctx, ...@@ -500,8 +500,22 @@ phi::InferMetaContext BuildInferMetaContext(InferShapeContext* ctx,
"Unsupported attribute type is received when call " "Unsupported attribute type is received when call "
"InferShapeFunctor.")); "InferShapeFunctor."));
} }
} else if (ctx->HasInput(attr_name)) {
// convert from data
if (attr_defs[i].type_index == std::type_index(typeid(int32_t))) {
if (ctx->IsRuntime()) {
const auto& infershape_inputs = ctx->GetInputVarPtrs(attr_name);
auto var_temp = BOOST_GET_CONST(Variable*, infershape_inputs[i]);
auto val = experimental::MakePhiScalarFromVar(*var_temp);
int32_t val_int = val.template to<int32_t>();
infer_meta_context.EmplaceBackAttr(val_int);
} else { } else {
// do nothing infer_meta_context.EmplaceBackAttr(-1);
}
} else {
PADDLE_THROW(platform::errors::Unimplemented(
"Get value from variable only support int yet"));
}
} }
} }
......
...@@ -2250,41 +2250,62 @@ void OperatorWithKernel::BuildPhiKernelContext( ...@@ -2250,41 +2250,62 @@ void OperatorWithKernel::BuildPhiKernelContext(
} }
} else { } else {
// TODO(chenweihang): support other attrs later // TODO(chenweihang): support other attrs later
auto& attr = Attrs().at(attr_names[i]); auto attr_it = attrs_.find(attr_names[i]);
if (attr_defs[i].type_index == std::type_index(typeid(int))) { if (attr_defs[i].type_index == std::type_index(typeid(int))) {
pt_kernel_context->EmplaceBackAttr(BOOST_GET_CONST(int, attr)); if (attr_it == attrs_.end()) {
auto in_it = ctx.inputs.find(attr_names[i]);
if (in_it != ctx.inputs.end()) {
// get data from input
auto val = experimental::MakePhiScalarFromVar(*(in_it->second[0]));
int32_t val_int = val.template to<int32_t>();
pt_kernel_context->EmplaceBackAttr(val_int);
} else {
PADDLE_THROW(platform::errors::NotFound(
"can not find attribute `%s` both in attribute and input ",
attr_names[i]));
}
} else {
pt_kernel_context->EmplaceBackAttr(
BOOST_GET_CONST(int, attr_it->second));
}
} else if (attr_defs[i].type_index == std::type_index(typeid(float))) { } else if (attr_defs[i].type_index == std::type_index(typeid(float))) {
pt_kernel_context->EmplaceBackAttr(BOOST_GET_CONST(float, attr)); pt_kernel_context->EmplaceBackAttr(
BOOST_GET_CONST(float, attr_it->second));
} else if (attr_defs[i].type_index == std::type_index(typeid(bool))) { } else if (attr_defs[i].type_index == std::type_index(typeid(bool))) {
pt_kernel_context->EmplaceBackAttr(BOOST_GET_CONST(bool, attr)); pt_kernel_context->EmplaceBackAttr(
BOOST_GET_CONST(bool, attr_it->second));
} else if (attr_defs[i].type_index == std::type_index(typeid(int64_t))) { } else if (attr_defs[i].type_index == std::type_index(typeid(int64_t))) {
pt_kernel_context->EmplaceBackAttr(BOOST_GET_CONST(int64_t, attr)); pt_kernel_context->EmplaceBackAttr(
BOOST_GET_CONST(int64_t, attr_it->second));
} else if (attr_defs[i].type_index == } else if (attr_defs[i].type_index ==
std::type_index(typeid(std::string))) { std::type_index(typeid(std::string))) {
pt_kernel_context->EmplaceBackAttr(BOOST_GET_CONST(std::string, attr)); pt_kernel_context->EmplaceBackAttr(
BOOST_GET_CONST(std::string, attr_it->second));
} else if (attr_defs[i].type_index == } else if (attr_defs[i].type_index ==
std::type_index(typeid(phi::DataType))) { std::type_index(typeid(phi::DataType))) {
auto data_type = paddle::framework::TransToPhiDataType( auto data_type = paddle::framework::TransToPhiDataType(
static_cast<framework::proto::VarType::Type>( static_cast<framework::proto::VarType::Type>(
BOOST_GET_CONST(int, attr))); BOOST_GET_CONST(int, attr_it->second)));
pt_kernel_context->EmplaceBackAttr(data_type); pt_kernel_context->EmplaceBackAttr(data_type);
} else if (attr_defs[i].type_index == } else if (attr_defs[i].type_index ==
std::type_index(typeid(std::vector<int64_t>))) { std::type_index(typeid(std::vector<int64_t>))) {
if (std::type_index(attr.type()) == if (std::type_index(attr_it->second.type()) ==
std::type_index(typeid(std::vector<int64_t>))) { std::type_index(typeid(std::vector<int64_t>))) {
pt_kernel_context->EmplaceBackAttr( pt_kernel_context->EmplaceBackAttr(
BOOST_GET_CONST(std::vector<int64_t>, attr)); BOOST_GET_CONST(std::vector<int64_t>, attr_it->second));
} else if (std::type_index(attr.type()) == } else if (std::type_index(attr_it->second.type()) ==
std::type_index(typeid(std::vector<int>))) { std::type_index(typeid(std::vector<int>))) {
// Emplace Back Attr according to the type of Phi_Kernel args. // Emplace Back Attr according to the type of Phi_Kernel args.
const auto& vector_int_attr = BOOST_GET_CONST(std::vector<int>, attr); const auto& vector_int_attr =
BOOST_GET_CONST(std::vector<int>, attr_it->second);
const std::vector<int64_t> vector_int64_attr(vector_int_attr.begin(), const std::vector<int64_t> vector_int64_attr(vector_int_attr.begin(),
vector_int_attr.end()); vector_int_attr.end());
pt_kernel_context->EmplaceBackAttr(vector_int64_attr); pt_kernel_context->EmplaceBackAttr(vector_int64_attr);
} }
} else if (attr_defs[i].type_index == } else if (attr_defs[i].type_index ==
std::type_index(typeid(std::vector<int32_t>))) { std::type_index(typeid(std::vector<int32_t>))) {
const auto& vector_int_attr = BOOST_GET_CONST(std::vector<int>, attr); const auto& vector_int_attr =
BOOST_GET_CONST(std::vector<int>, attr_it->second);
pt_kernel_context->EmplaceBackAttr(vector_int_attr); pt_kernel_context->EmplaceBackAttr(vector_int_attr);
} else { } else {
PADDLE_THROW(platform::errors::Unimplemented( PADDLE_THROW(platform::errors::Unimplemented(
......
...@@ -419,6 +419,17 @@ void BuildDygraphPhiKernelContext( ...@@ -419,6 +419,17 @@ void BuildDygraphPhiKernelContext(
experimental::MakePhiScalarFromVar(ins_vector[0]->Var()))); experimental::MakePhiScalarFromVar(ins_vector[0]->Var())));
} }
} else if (ins.find(attr_names[i]) != ins.end()) {
// deal tensor attr here
auto& ins_vector = ins.at(attr_names[i]);
auto tensor_attr =
experimental::MakePhiScalarFromVar(ins_vector[0]->Var());
if (attr_defs[i].type_index == std::type_index(typeid(int))) {
int val = tensor_attr.template to<int>();
kernel_ctx->EmplaceBackAttr(val);
} else {
PADDLE_THROW(platform::errors::Unimplemented("only support int here"));
}
} else if (attr_defs[i].type_index == } else if (attr_defs[i].type_index ==
std::type_index(typeid(std::vector<phi::Scalar>))) { std::type_index(typeid(std::vector<phi::Scalar>))) {
auto& attr = GetAttr(attrs, default_attrs, attr_names[i]); auto& attr = GetAttr(attrs, default_attrs, attr_names[i]);
...@@ -475,6 +486,7 @@ void BuildDygraphPhiKernelContext( ...@@ -475,6 +486,7 @@ void BuildDygraphPhiKernelContext(
} }
} else { } else {
// TODO(chenweihang): support other attrs later // TODO(chenweihang): support other attrs later
auto& attr = GetAttr(attrs, default_attrs, attr_names[i]); auto& attr = GetAttr(attrs, default_attrs, attr_names[i]);
if (attr_defs[i].type_index == std::type_index(typeid(int))) { if (attr_defs[i].type_index == std::type_index(typeid(int))) {
kernel_ctx->EmplaceBackAttr(BOOST_GET_CONST(int, attr)); kernel_ctx->EmplaceBackAttr(BOOST_GET_CONST(int, attr));
......
...@@ -12,9 +12,13 @@ ...@@ -12,9 +12,13 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "paddle/fluid/operators/one_hot_v2_op.h"
#include <string> #include <string>
#include <vector> #include <vector>
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/infermeta/backward.h"
#include "paddle/phi/infermeta/unary.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
...@@ -22,26 +26,6 @@ namespace operators { ...@@ -22,26 +26,6 @@ namespace operators {
class OneHotV2Op : public framework::OperatorWithKernel { class OneHotV2Op : public framework::OperatorWithKernel {
public: public:
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "one_hot_v2");
OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "one_hot_v2");
auto x_dims = ctx->GetInputDim("X");
PADDLE_ENFORCE_GE(x_dims.size(), 1,
platform::errors::InvalidArgument(
"Rank of Input(X) should be at least 1."));
int depth = ctx->Attrs().Get<int>("depth");
if (ctx->HasInput("depth_tensor")) {
depth = -1;
}
auto out_dims_vec = phi::vectorize(x_dims);
out_dims_vec.push_back(depth);
auto out_dims = phi::make_ddim(out_dims_vec);
ctx->SetOutputDim("Out", out_dims);
ctx->ShareLoD("X", /* --> */ "Out");
}
protected: protected:
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
...@@ -52,7 +36,7 @@ class OneHotV2Op : public framework::OperatorWithKernel { ...@@ -52,7 +36,7 @@ class OneHotV2Op : public framework::OperatorWithKernel {
} }
framework::OpKernelType GetKernelTypeForVar( framework::OpKernelType GetKernelTypeForVar(
const std::string& var_name, const Tensor& tensor, const std::string& var_name, const framework::Tensor& tensor,
const framework::OpKernelType& expected_kernel_type) const override { const framework::OpKernelType& expected_kernel_type) const override {
if (var_name == "depth_tensor") { if (var_name == "depth_tensor") {
return expected_kernel_type; return expected_kernel_type;
...@@ -114,10 +98,12 @@ Out is a LoDTensor: ...@@ -114,10 +98,12 @@ Out is a LoDTensor:
} // namespace paddle } // namespace paddle
namespace ops = paddle::operators; namespace ops = paddle::operators;
DECLARE_INFER_SHAPE_FUNCTOR(one_hot_v2, OneHotInferShapeFunctor,
PD_INFER_META(phi::OneHotRawInferMeta));
REGISTER_OPERATOR( REGISTER_OPERATOR(
one_hot_v2, ops::OneHotV2Op, ops::OneHotV2OpMaker, one_hot_v2, ops::OneHotV2Op, ops::OneHotV2OpMaker,
paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>, paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>); paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>,
REGISTER_OP_CPU_KERNEL( OneHotInferShapeFunctor);
one_hot_v2, ops::OneHotV2Kernel<paddle::platform::CPUDeviceContext, int>,
ops::OneHotV2Kernel<paddle::platform::CPUDeviceContext, int64_t>);
...@@ -12,13 +12,14 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,13 +12,14 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/operators/one_hot_v2_op.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/platform/device/npu/npu_op_runner.h" #include "paddle/fluid/platform/device/npu/npu_op_runner.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
using Tensor = framework::Tensor; using Tensor = framework::Tensor;
using LoDTensor = framework::LoDTensor;
template <typename T> template <typename T>
class OneHotV2NPUKernel : public framework::OpKernel<T> { class OneHotV2NPUKernel : public framework::OpKernel<T> {
......
...@@ -55,6 +55,7 @@ const std::unordered_set<std::string> deprecated_op_names({"diag", ...@@ -55,6 +55,7 @@ const std::unordered_set<std::string> deprecated_op_names({"diag",
"expand_grad", "expand_grad",
"expand_as_grad", "expand_as_grad",
"sum", "sum",
"one_hot",
"sum_grad", "sum_grad",
"top_k", "top_k",
"top_k_grad"}); "top_k_grad"});
......
...@@ -72,6 +72,10 @@ void MetaTensor::set_layout(DataLayout layout) { ...@@ -72,6 +72,10 @@ void MetaTensor::set_layout(DataLayout layout) {
} }
void MetaTensor::share_lod(const MetaTensor& meta_tensor) { void MetaTensor::share_lod(const MetaTensor& meta_tensor) {
if (meta_tensor.lod().size() == 0) {
// no need share
return;
}
if (phi::DenseTensor::classof(tensor_)) { if (phi::DenseTensor::classof(tensor_)) {
DenseTensorUtils::GetMutableMeta(static_cast<DenseTensor*>(tensor_))->lod = DenseTensorUtils::GetMutableMeta(static_cast<DenseTensor*>(tensor_))->lod =
meta_tensor.lod(); meta_tensor.lod();
......
...@@ -1602,6 +1602,43 @@ void UnfoldInferMeta(const MetaTensor& x, ...@@ -1602,6 +1602,43 @@ void UnfoldInferMeta(const MetaTensor& x,
out->set_dims(phi::make_ddim(out_dims)); out->set_dims(phi::make_ddim(out_dims));
} }
void OneHotRawInferMeta(const MetaTensor& x,
int32_t depth,
DataType dtype,
bool allow_out_of_range,
MetaTensor* out) {
auto x_dims = x.dims();
PADDLE_ENFORCE_GE(
x_dims.size(),
1,
phi::errors::InvalidArgument("Rank of Input(X) should be at least 1."));
auto out_dims_vec = phi::vectorize(x_dims);
out_dims_vec.push_back(depth);
auto out_dims = phi::make_ddim(out_dims_vec);
out->set_dims(out_dims);
out->share_lod(x);
out->set_dtype(dtype);
}
void OneHotInferMeta(const MetaTensor& x,
const Scalar& depth_t,
MetaTensor* out) {
auto x_dims = x.dims();
PADDLE_ENFORCE_GE(
x_dims.size(),
1,
phi::errors::InvalidArgument("Rank of Input(X) should be at least 1."));
int depth = depth_t.to<int>();
auto out_dims_vec = phi::vectorize(x_dims);
out_dims_vec.push_back(depth);
auto out_dims = phi::make_ddim(out_dims_vec);
out->set_dims(out_dims);
out->share_lod(x);
out->set_dtype(phi::DataType::FLOAT32);
}
void WhereIndexInferMeta(const MetaTensor& condition, MetaTensor* out) { void WhereIndexInferMeta(const MetaTensor& condition, MetaTensor* out) {
auto rank = condition.dims().size(); auto rank = condition.dims().size();
PADDLE_ENFORCE_GE( PADDLE_ENFORCE_GE(
......
...@@ -228,6 +228,14 @@ void UnfoldInferMeta(const MetaTensor& x, ...@@ -228,6 +228,14 @@ void UnfoldInferMeta(const MetaTensor& x,
MetaTensor* out, MetaTensor* out,
MetaConfig config = MetaConfig()); MetaConfig config = MetaConfig());
void OneHotRawInferMeta(const MetaTensor& x,
int32_t depth,
DataType dtype,
bool allow_out_of_range,
MetaTensor* out);
void OneHotInferMeta(const MetaTensor& x, const Scalar& depth, MetaTensor* out);
void WhereIndexInferMeta(const MetaTensor& condition, MetaTensor* out); void WhereIndexInferMeta(const MetaTensor& condition, MetaTensor* out);
} // namespace phi } // namespace phi
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. // Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
// //
// Licensed under the Apache License, Version 2.0 (the "License"); // Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License. // you may not use this file except in compliance with the License.
...@@ -12,23 +12,25 @@ ...@@ -12,23 +12,25 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#pragma once #include "paddle/phi/kernels/one_hot_kernel.h"
#include "paddle/fluid/framework/op_registry.h" #include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/utils/data_type.h"
#include "paddle/phi/kernels/funcs/math_function.h" #include "paddle/phi/kernels/funcs/math_function.h"
namespace paddle { namespace phi {
namespace operators {
template <typename DeviceContext, typename InT> template <typename DeviceContext, typename InT>
struct OneHotV2OpFunctor { struct OneHotV2OpFunctor {
const framework::LoDTensor* in_; const DenseTensor* in_;
framework::LoDTensor* out_; DenseTensor* out_;
int depth_; int depth_;
const DeviceContext& ctx_; const DeviceContext& ctx_;
bool allow_out_of_range_; bool allow_out_of_range_;
OneHotV2OpFunctor(const framework::LoDTensor* in, framework::LoDTensor* out, OneHotV2OpFunctor(const DenseTensor* in,
int depth, const DeviceContext& ctx, DenseTensor* out,
int depth,
const DeviceContext& ctx,
bool allow_out_of_range = false) bool allow_out_of_range = false)
: in_(in), : in_(in),
out_(out), out_(out),
...@@ -40,8 +42,8 @@ struct OneHotV2OpFunctor { ...@@ -40,8 +42,8 @@ struct OneHotV2OpFunctor {
void apply() const { void apply() const {
auto* p_in_data = in_->data<InT>(); auto* p_in_data = in_->data<InT>();
auto numel = in_->numel(); auto numel = in_->numel();
auto* p_out_data = out_->mutable_data<OutT>(ctx_.GetPlace()); auto* p_out_data = ctx_.template Alloc<OutT>(out_);
phi::funcs::set_constant(ctx_, out_, 0.0); funcs::set_constant(ctx_, out_, 0.0);
if (allow_out_of_range_) { if (allow_out_of_range_) {
for (int i = 0; i < numel; ++i) { for (int i = 0; i < numel; ++i) {
...@@ -52,51 +54,46 @@ struct OneHotV2OpFunctor { ...@@ -52,51 +54,46 @@ struct OneHotV2OpFunctor {
} else { } else {
for (int i = 0; i < numel; ++i) { for (int i = 0; i < numel; ++i) {
PADDLE_ENFORCE_GE( PADDLE_ENFORCE_GE(
p_in_data[i], 0, p_in_data[i],
platform::errors::InvalidArgument( 0,
phi::errors::InvalidArgument(
"Illegal index value, Input(input) value should be at least 0, " "Illegal index value, Input(input) value should be at least 0, "
"but received input (%d) less than 0", "but received input (%d) less than 0",
p_in_data[i])); p_in_data[i]));
PADDLE_ENFORCE_LT( PADDLE_ENFORCE_LT(
p_in_data[i], depth_, p_in_data[i],
platform::errors::InvalidArgument( depth_,
phi::errors::InvalidArgument(
"Illegal index value, Input(input) value should be less than " "Illegal index value, Input(input) value should be less than "
"Input(depth), " "Input(depth), "
"but received input (%d) not less than depth (%d)", "but received input (%d) not less than depth (%d)",
p_in_data[i], depth_)); p_in_data[i],
depth_));
*(p_out_data + i * depth_ + p_in_data[i]) = 1.0; *(p_out_data + i * depth_ + p_in_data[i]) = 1.0;
} }
} }
} }
}; };
using LoDTensor = framework::LoDTensor; template <typename T, typename Context>
using Tensor = framework::Tensor; void OneHotRawKernel(const Context& dev_ctx,
template <typename DeviceContext, typename T> const DenseTensor& x,
class OneHotV2Kernel : public framework::OpKernel<T> { int32_t depth,
public: DataType dtype,
void Compute(const framework::ExecutionContext& context) const override { bool allow_out_of_range,
auto* in = context.Input<LoDTensor>("X"); DenseTensor* out) {
auto* out = context.Output<LoDTensor>("Out");
int depth = context.Attr<int>("depth");
bool allow_out_of_range = context.Attr<bool>("allow_out_of_range");
if (context.HasInput("depth_tensor")) {
auto* depth_tensor = context.Input<Tensor>("depth_tensor");
auto* depth_data = depth_tensor->data<int32_t>();
depth = depth_data[0];
auto out_dims = out->dims(); auto out_dims = out->dims();
if (out_dims[out_dims.size() - 1] == -1) {
out_dims[out_dims.size() - 1] = depth; out_dims[out_dims.size() - 1] = depth;
out->Resize(out_dims); out->Resize(out_dims);
} }
framework::VisitDataType( phi::VisitDataType(dtype,
static_cast<framework::proto::VarType::Type>( OneHotV2OpFunctor<Context, T>(
context.Attr<int>("dtype")), &x, out, depth, dev_ctx, allow_out_of_range));
OneHotV2OpFunctor<DeviceContext, T>( }
in, out, depth, context.template device_context<DeviceContext>(),
allow_out_of_range)); } // namespace phi
}
};
} // namespace operators PD_REGISTER_KERNEL(
} // namespace paddle one_hot_raw, CPU, ALL_LAYOUT, phi::OneHotRawKernel, int, int64_t) {}
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. // Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
// //
// Licensed under the Apache License, Version 2.0 (the "License"); // Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License. // you may not use this file except in compliance with the License.
...@@ -12,17 +12,22 @@ ...@@ -12,17 +12,22 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "paddle/fluid/operators/one_hot_v2_op.h" #include "paddle/phi/kernels/one_hot_kernel.h"
#include "paddle/fluid/platform/device/gpu/gpu_info.h" #include "paddle/fluid/platform/device/gpu/gpu_info.h"
#include "paddle/fluid/platform/device/gpu/gpu_primitives.h" #include "paddle/fluid/platform/device/gpu/gpu_primitives.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/math_function.h"
namespace phi {
namespace paddle { using paddle::platform::PADDLE_CUDA_NUM_THREADS;
namespace operators {
using platform::PADDLE_CUDA_NUM_THREADS;
template <typename InT, typename OutT> template <typename InT, typename OutT>
__global__ void FillOutputKernel(const InT* p_in_data, OutT* p_out_data, __global__ void FillOutputKernel(const InT* p_in_data,
const int64_t numel, const int depth) { OutT* p_out_data,
const int64_t numel,
const int depth) {
int idx = blockIdx.x * blockDim.x + threadIdx.x; int idx = blockIdx.x * blockDim.x + threadIdx.x;
if (idx < numel && p_in_data[idx] >= 0 && p_in_data[idx] < depth) { if (idx < numel && p_in_data[idx] >= 0 && p_in_data[idx] < depth) {
*(p_out_data + (idx * depth) + p_in_data[idx]) = 1.0; *(p_out_data + (idx * depth) + p_in_data[idx]) = 1.0;
...@@ -31,13 +36,14 @@ __global__ void FillOutputKernel(const InT* p_in_data, OutT* p_out_data, ...@@ -31,13 +36,14 @@ __global__ void FillOutputKernel(const InT* p_in_data, OutT* p_out_data,
template <typename DeviceContext, typename InT> template <typename DeviceContext, typename InT>
struct OneHotV2OpCUDAFunctor { struct OneHotV2OpCUDAFunctor {
const framework::LoDTensor* in_; const DenseTensor* in_;
framework::LoDTensor* out_; DenseTensor* out_;
const DeviceContext& ctx_; const DeviceContext& ctx_;
int depth_; int depth_;
OneHotV2OpCUDAFunctor(const framework::LoDTensor* in, OneHotV2OpCUDAFunctor(const DenseTensor* in,
framework::LoDTensor* out, int depth, DenseTensor* out,
int depth,
const DeviceContext& ctx) const DeviceContext& ctx)
: in_(in), out_(out), depth_(depth), ctx_(ctx) {} : in_(in), out_(out), depth_(depth), ctx_(ctx) {}
...@@ -45,56 +51,36 @@ struct OneHotV2OpCUDAFunctor { ...@@ -45,56 +51,36 @@ struct OneHotV2OpCUDAFunctor {
void apply() const { void apply() const {
auto* p_in_data = in_->data<InT>(); auto* p_in_data = in_->data<InT>();
auto numel = in_->numel(); auto numel = in_->numel();
auto* p_out_data = out_->mutable_data<OutT>(ctx_.GetPlace()); auto* p_out_data = ctx_.template Alloc<OutT>(out_);
auto stream = ctx_.stream(); auto stream = ctx_.stream();
phi::funcs::set_constant(ctx_, out_, 0.0); funcs::set_constant(ctx_, out_, 0.0);
FillOutputKernel<<<(numel + PADDLE_CUDA_NUM_THREADS - 1) / FillOutputKernel<<<(numel + PADDLE_CUDA_NUM_THREADS - 1) /
PADDLE_CUDA_NUM_THREADS, PADDLE_CUDA_NUM_THREADS,
PADDLE_CUDA_NUM_THREADS, 0, stream>>>( PADDLE_CUDA_NUM_THREADS,
p_in_data, p_out_data, numel, depth_); 0,
stream>>>(p_in_data, p_out_data, numel, depth_);
} }
}; };
using LoDTensor = framework::LoDTensor; template <typename T, typename Context>
template <typename DeviceContext, typename T> void OneHotRawKernel(const Context& dev_ctx,
class OneHotV2CUDAKernel : public framework::OpKernel<T> { const DenseTensor& x,
public: int32_t depth,
void Compute(const framework::ExecutionContext& context) const override { DataType dtype,
auto* in = context.Input<LoDTensor>("X"); bool allow_out_of_range,
auto* out = context.Output<LoDTensor>("Out"); DenseTensor* out) {
int depth = -1;
if (context.HasInput("depth_tensor")) {
auto* depth_tensor = context.Input<framework::Tensor>("depth_tensor");
if (platform::is_gpu_place(depth_tensor->place())) {
framework::Tensor temp;
paddle::framework::TensorCopySync(*depth_tensor, platform::CPUPlace(),
&temp);
depth = *temp.data<int32_t>();
} else {
depth = *depth_tensor->data<int32_t>();
}
auto out_dims = out->dims(); auto out_dims = out->dims();
if (out_dims[out_dims.size() - 1] == -1) {
out_dims[out_dims.size() - 1] = depth; out_dims[out_dims.size() - 1] = depth;
out->Resize(out_dims); out->Resize(out_dims);
} else {
depth = context.Attr<int>("depth");
}
framework::VisitDataType(
static_cast<framework::proto::VarType::Type>(
context.Attr<int>("dtype")),
OneHotV2OpCUDAFunctor<DeviceContext, T>(
in, out, depth, context.template device_context<DeviceContext>()));
} }
};
} // namespace operators phi::VisitDataType(
} // namespace paddle dtype, OneHotV2OpCUDAFunctor<Context, T>(&x, out, depth, dev_ctx));
}
} // namespace phi
namespace ops = paddle::operators; PD_REGISTER_KERNEL(
REGISTER_OP_CUDA_KERNEL( one_hot_raw, GPU, ALL_LAYOUT, phi::OneHotRawKernel, int, int64_t) {}
one_hot_v2,
ops::OneHotV2CUDAKernel<paddle::platform::CUDADeviceContext, int>,
ops::OneHotV2CUDAKernel<paddle::platform::CUDADeviceContext, int64_t>);
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/one_hot_kernel.h"
#include "paddle/phi/backends/all_context.h"
#include "paddle/phi/core/kernel_registry.h"
namespace phi {
template <typename T, typename Context>
void OneHotKernel(const Context& dev_ctx,
const DenseTensor& x,
const Scalar& num_classes_s,
DenseTensor* out) {
int num_classes = num_classes_s.to<int>();
OneHotRawKernel<T>(
dev_ctx, x, num_classes, phi::DataType::FLOAT32, false, out);
}
} // namespace phi
PD_REGISTER_KERNEL(one_hot, CPU, ALL_LAYOUT, phi::OneHotKernel, int, int64_t) {}
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
PD_REGISTER_KERNEL(one_hot, GPU, ALL_LAYOUT, phi::OneHotKernel, int, int64_t) {}
#endif
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/common/scalar.h"
#include "paddle/phi/core/dense_tensor.h"
namespace phi {
template <typename T, typename Context>
void OneHotKernel(const Context& dev_ctx,
const DenseTensor& x,
const Scalar& num_classes,
DenseTensor* out);
template <typename T, typename Context>
void OneHotRawKernel(const Context& dev_ctx,
const DenseTensor& x,
int32_t depth,
DataType dtype,
bool allow_out_of_range,
DenseTensor* out);
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/core/compat/op_utils.h"
namespace phi {
KernelSignature OneHotOpArgumentMapping(const ArgumentMappingContext& ctx) {
if (ctx.HasInput("depth_tensor")) {
return KernelSignature("one_hot_raw",
{"X"},
{"depth_tensor", "dtype", "allow_out_of_range"},
{"Out"});
} else {
return KernelSignature("one_hot_raw",
{"X"},
{"depth", "dtype", "allow_out_of_range"},
{"Out"});
}
}
} // namespace phi
PD_REGISTER_BASE_KERNEL_NAME(one_hot_v2, one_hot);
PD_REGISTER_ARG_MAPPING_FN(one_hot_v2, phi::OneHotOpArgumentMapping);
...@@ -52,6 +52,12 @@ final_state_name_mapping = { ...@@ -52,6 +52,12 @@ final_state_name_mapping = {
"axis1": "axis1", "axis1": "axis1",
"axis2": "axis2", "axis2": "axis2",
"out": "Out", "out": "Out",
},
"one_hot": {
"final_op_name": "final_state_one_hot",
"x": "X",
"num_class": "depth",
"out": "Out",
} }
} }
......
...@@ -22,7 +22,8 @@ import paddle ...@@ -22,7 +22,8 @@ import paddle
import paddle.fluid as fluid import paddle.fluid as fluid
import paddle.fluid.core as core import paddle.fluid.core as core
import paddle.fluid.framework as framework import paddle.fluid.framework as framework
from paddle.fluid.framework import Program, program_guard from paddle.framework import _in_eager_mode
from paddle.fluid.framework import Program, program_guard, _test_eager_guard
class TestOneHotOp(OpTest): class TestOneHotOp(OpTest):
...@@ -45,7 +46,7 @@ class TestOneHotOp(OpTest): ...@@ -45,7 +46,7 @@ class TestOneHotOp(OpTest):
self.outputs = {'Out': (out, x_lod)} self.outputs = {'Out': (out, x_lod)}
def test_check_output(self): def test_check_output(self):
self.check_output(check_dygraph=False) self.check_output()
class TestOneHotOp_attr(OpTest): class TestOneHotOp_attr(OpTest):
...@@ -68,7 +69,7 @@ class TestOneHotOp_attr(OpTest): ...@@ -68,7 +69,7 @@ class TestOneHotOp_attr(OpTest):
self.outputs = {'Out': (out, x_lod)} self.outputs = {'Out': (out, x_lod)}
def test_check_output(self): def test_check_output(self):
self.check_output(check_dygraph=False) self.check_output()
class TestOneHotOp_default_dtype(OpTest): class TestOneHotOp_default_dtype(OpTest):
...@@ -91,7 +92,7 @@ class TestOneHotOp_default_dtype(OpTest): ...@@ -91,7 +92,7 @@ class TestOneHotOp_default_dtype(OpTest):
self.outputs = {'Out': (out, x_lod)} self.outputs = {'Out': (out, x_lod)}
def test_check_output(self): def test_check_output(self):
self.check_output(check_dygraph=False) self.check_output()
class TestOneHotOp_default_dtype_attr(OpTest): class TestOneHotOp_default_dtype_attr(OpTest):
...@@ -114,7 +115,7 @@ class TestOneHotOp_default_dtype_attr(OpTest): ...@@ -114,7 +115,7 @@ class TestOneHotOp_default_dtype_attr(OpTest):
self.outputs = {'Out': (out, x_lod)} self.outputs = {'Out': (out, x_lod)}
def test_check_output(self): def test_check_output(self):
self.check_output(check_dygraph=False) self.check_output()
class TestOneHotOp_out_of_range(OpTest): class TestOneHotOp_out_of_range(OpTest):
...@@ -132,7 +133,7 @@ class TestOneHotOp_out_of_range(OpTest): ...@@ -132,7 +133,7 @@ class TestOneHotOp_out_of_range(OpTest):
self.outputs = {'Out': (out, x_lod)} self.outputs = {'Out': (out, x_lod)}
def test_check_output(self): def test_check_output(self):
self.check_output(check_dygraph=False) self.check_output()
class TestOneHotOp_exception(unittest.TestCase): class TestOneHotOp_exception(unittest.TestCase):
...@@ -190,6 +191,12 @@ class TestOneHotOpApi(unittest.TestCase): ...@@ -190,6 +191,12 @@ class TestOneHotOpApi(unittest.TestCase):
one_hot_label = fluid.one_hot( one_hot_label = fluid.one_hot(
input=fluid.dygraph.to_variable(label), depth=depth) input=fluid.dygraph.to_variable(label), depth=depth)
one_hot_label = paddle.nn.functional.one_hot(
fluid.dygraph.to_variable(label), depth)
with _test_eager_guard():
one_hot_label = paddle.nn.functional.one_hot(
paddle.to_tensor(label), depth)
def _run(self, depth): def _run(self, depth):
label = fluid.layers.data(name="label", shape=[1], dtype="int64") label = fluid.layers.data(name="label", shape=[1], dtype="int64")
one_hot_label = fluid.one_hot(input=label, depth=depth) one_hot_label = fluid.one_hot(input=label, depth=depth)
......
...@@ -19,6 +19,7 @@ from ...fluid.layer_helper import LayerHelper ...@@ -19,6 +19,7 @@ from ...fluid.layer_helper import LayerHelper
from ...fluid.data_feeder import check_variable_and_dtype, check_dtype from ...fluid.data_feeder import check_variable_and_dtype, check_dtype
from paddle import _C_ops from paddle import _C_ops
from paddle import in_dynamic_mode from paddle import in_dynamic_mode
from paddle.framework import _in_eager_mode
__all__ = [] __all__ = []
...@@ -87,6 +88,8 @@ def one_hot(x, num_classes, name=None): ...@@ -87,6 +88,8 @@ def one_hot(x, num_classes, name=None):
""" """
if in_dynamic_mode(): if in_dynamic_mode():
if _in_eager_mode():
return _C_ops.final_state_one_hot(x, num_classes)
return _C_ops.one_hot_v2(x, 'depth', num_classes, 'allow_out_of_range', return _C_ops.one_hot_v2(x, 'depth', num_classes, 'allow_out_of_range',
False) False)
else: else:
......
...@@ -204,6 +204,15 @@ ...@@ -204,6 +204,15 @@
output : Tensor output : Tensor
invoke : full_like(x, 0, dtype, place) invoke : full_like(x, 0, dtype, place)
- api : one_hot
args : (Tensor x, Scalar num_classes)
output : Tensor
infer_meta :
func : OneHotInferMeta
kernel :
func : one_hot
- api : digamma - api : digamma
args : (Tensor x) args : (Tensor x)
output : Tensor output : Tensor
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册