diff --git a/paddle/fluid/framework/infershape_utils.cc b/paddle/fluid/framework/infershape_utils.cc index 5119c306906915cb3d6f7646a4338b8b6fa24ef7..b1d7059f311cd370a40e83d7b0016d5af8cdb163 100644 --- a/paddle/fluid/framework/infershape_utils.cc +++ b/paddle/fluid/framework/infershape_utils.cc @@ -500,8 +500,22 @@ phi::InferMetaContext BuildInferMetaContext(InferShapeContext* ctx, "Unsupported attribute type is received when call " "InferShapeFunctor.")); } - } else { - // do nothing + } else if (ctx->HasInput(attr_name)) { + // convert from data + if (attr_defs[i].type_index == std::type_index(typeid(int32_t))) { + if (ctx->IsRuntime()) { + const auto& infershape_inputs = ctx->GetInputVarPtrs(attr_name); + auto var_temp = BOOST_GET_CONST(Variable*, infershape_inputs[i]); + auto val = experimental::MakePhiScalarFromVar(*var_temp); + int32_t val_int = val.template to(); + infer_meta_context.EmplaceBackAttr(val_int); + } else { + infer_meta_context.EmplaceBackAttr(-1); + } + } else { + PADDLE_THROW(platform::errors::Unimplemented( + "Get value from variable only support int yet")); + } } } diff --git a/paddle/fluid/framework/operator.cc b/paddle/fluid/framework/operator.cc index f8e30c1ee294ecf692e2992b6123232ba1c8bd7d..f23a266ef03641bc8f8d273b15ab4982e377cb03 100644 --- a/paddle/fluid/framework/operator.cc +++ b/paddle/fluid/framework/operator.cc @@ -2250,41 +2250,62 @@ void OperatorWithKernel::BuildPhiKernelContext( } } else { // TODO(chenweihang): support other attrs later - auto& attr = Attrs().at(attr_names[i]); + auto attr_it = attrs_.find(attr_names[i]); if (attr_defs[i].type_index == std::type_index(typeid(int))) { - pt_kernel_context->EmplaceBackAttr(BOOST_GET_CONST(int, attr)); + if (attr_it == attrs_.end()) { + auto in_it = ctx.inputs.find(attr_names[i]); + if (in_it != ctx.inputs.end()) { + // get data from input + auto val = experimental::MakePhiScalarFromVar(*(in_it->second[0])); + int32_t val_int = val.template to(); + pt_kernel_context->EmplaceBackAttr(val_int); + } else { + PADDLE_THROW(platform::errors::NotFound( + "can not find attribute `%s` both in attribute and input ", + attr_names[i])); + } + } else { + pt_kernel_context->EmplaceBackAttr( + BOOST_GET_CONST(int, attr_it->second)); + } } else if (attr_defs[i].type_index == std::type_index(typeid(float))) { - pt_kernel_context->EmplaceBackAttr(BOOST_GET_CONST(float, attr)); + pt_kernel_context->EmplaceBackAttr( + BOOST_GET_CONST(float, attr_it->second)); } else if (attr_defs[i].type_index == std::type_index(typeid(bool))) { - pt_kernel_context->EmplaceBackAttr(BOOST_GET_CONST(bool, attr)); + pt_kernel_context->EmplaceBackAttr( + BOOST_GET_CONST(bool, attr_it->second)); } else if (attr_defs[i].type_index == std::type_index(typeid(int64_t))) { - pt_kernel_context->EmplaceBackAttr(BOOST_GET_CONST(int64_t, attr)); + pt_kernel_context->EmplaceBackAttr( + BOOST_GET_CONST(int64_t, attr_it->second)); } else if (attr_defs[i].type_index == std::type_index(typeid(std::string))) { - pt_kernel_context->EmplaceBackAttr(BOOST_GET_CONST(std::string, attr)); + pt_kernel_context->EmplaceBackAttr( + BOOST_GET_CONST(std::string, attr_it->second)); } else if (attr_defs[i].type_index == std::type_index(typeid(phi::DataType))) { auto data_type = paddle::framework::TransToPhiDataType( static_cast( - BOOST_GET_CONST(int, attr))); + BOOST_GET_CONST(int, attr_it->second))); pt_kernel_context->EmplaceBackAttr(data_type); } else if (attr_defs[i].type_index == std::type_index(typeid(std::vector))) { - if (std::type_index(attr.type()) == + if (std::type_index(attr_it->second.type()) == std::type_index(typeid(std::vector))) { pt_kernel_context->EmplaceBackAttr( - BOOST_GET_CONST(std::vector, attr)); - } else if (std::type_index(attr.type()) == + BOOST_GET_CONST(std::vector, attr_it->second)); + } else if (std::type_index(attr_it->second.type()) == std::type_index(typeid(std::vector))) { // Emplace Back Attr according to the type of Phi_Kernel args. - const auto& vector_int_attr = BOOST_GET_CONST(std::vector, attr); + const auto& vector_int_attr = + BOOST_GET_CONST(std::vector, attr_it->second); const std::vector vector_int64_attr(vector_int_attr.begin(), vector_int_attr.end()); pt_kernel_context->EmplaceBackAttr(vector_int64_attr); } } else if (attr_defs[i].type_index == std::type_index(typeid(std::vector))) { - const auto& vector_int_attr = BOOST_GET_CONST(std::vector, attr); + const auto& vector_int_attr = + BOOST_GET_CONST(std::vector, attr_it->second); pt_kernel_context->EmplaceBackAttr(vector_int_attr); } else { PADDLE_THROW(platform::errors::Unimplemented( diff --git a/paddle/fluid/imperative/prepared_operator.h b/paddle/fluid/imperative/prepared_operator.h index 91e6974fa2edd1996dafa567ce9f2279d7cc4569..8deb3b93e9c50489dcfc6805063f23e3705cb634 100644 --- a/paddle/fluid/imperative/prepared_operator.h +++ b/paddle/fluid/imperative/prepared_operator.h @@ -419,6 +419,17 @@ void BuildDygraphPhiKernelContext( experimental::MakePhiScalarFromVar(ins_vector[0]->Var()))); } + } else if (ins.find(attr_names[i]) != ins.end()) { + // deal tensor attr here + auto& ins_vector = ins.at(attr_names[i]); + auto tensor_attr = + experimental::MakePhiScalarFromVar(ins_vector[0]->Var()); + if (attr_defs[i].type_index == std::type_index(typeid(int))) { + int val = tensor_attr.template to(); + kernel_ctx->EmplaceBackAttr(val); + } else { + PADDLE_THROW(platform::errors::Unimplemented("only support int here")); + } } else if (attr_defs[i].type_index == std::type_index(typeid(std::vector))) { auto& attr = GetAttr(attrs, default_attrs, attr_names[i]); @@ -475,6 +486,7 @@ void BuildDygraphPhiKernelContext( } } else { // TODO(chenweihang): support other attrs later + auto& attr = GetAttr(attrs, default_attrs, attr_names[i]); if (attr_defs[i].type_index == std::type_index(typeid(int))) { kernel_ctx->EmplaceBackAttr(BOOST_GET_CONST(int, attr)); diff --git a/paddle/fluid/operators/one_hot_v2_op.cc b/paddle/fluid/operators/one_hot_v2_op.cc index e212f4e7e2b7d1ad7964cc9351f1c4e241d5a79e..122b6a8a80aac95ab98ad95ed3e6339684978d12 100644 --- a/paddle/fluid/operators/one_hot_v2_op.cc +++ b/paddle/fluid/operators/one_hot_v2_op.cc @@ -12,9 +12,13 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "paddle/fluid/operators/one_hot_v2_op.h" #include #include +#include "paddle/fluid/framework/infershape_utils.h" +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/phi/core/infermeta_utils.h" +#include "paddle/phi/infermeta/backward.h" +#include "paddle/phi/infermeta/unary.h" namespace paddle { namespace operators { @@ -22,26 +26,6 @@ namespace operators { class OneHotV2Op : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; - void InferShape(framework::InferShapeContext* ctx) const override { - OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "one_hot_v2"); - OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "one_hot_v2"); - - auto x_dims = ctx->GetInputDim("X"); - PADDLE_ENFORCE_GE(x_dims.size(), 1, - platform::errors::InvalidArgument( - "Rank of Input(X) should be at least 1.")); - - int depth = ctx->Attrs().Get("depth"); - if (ctx->HasInput("depth_tensor")) { - depth = -1; - } - - auto out_dims_vec = phi::vectorize(x_dims); - out_dims_vec.push_back(depth); - auto out_dims = phi::make_ddim(out_dims_vec); - ctx->SetOutputDim("Out", out_dims); - ctx->ShareLoD("X", /* --> */ "Out"); - } protected: framework::OpKernelType GetExpectedKernelType( @@ -52,7 +36,7 @@ class OneHotV2Op : public framework::OperatorWithKernel { } framework::OpKernelType GetKernelTypeForVar( - const std::string& var_name, const Tensor& tensor, + const std::string& var_name, const framework::Tensor& tensor, const framework::OpKernelType& expected_kernel_type) const override { if (var_name == "depth_tensor") { return expected_kernel_type; @@ -114,10 +98,12 @@ Out is a LoDTensor: } // namespace paddle namespace ops = paddle::operators; + +DECLARE_INFER_SHAPE_FUNCTOR(one_hot_v2, OneHotInferShapeFunctor, + PD_INFER_META(phi::OneHotRawInferMeta)); + REGISTER_OPERATOR( one_hot_v2, ops::OneHotV2Op, ops::OneHotV2OpMaker, paddle::framework::EmptyGradOpMaker, - paddle::framework::EmptyGradOpMaker); -REGISTER_OP_CPU_KERNEL( - one_hot_v2, ops::OneHotV2Kernel, - ops::OneHotV2Kernel); + paddle::framework::EmptyGradOpMaker, + OneHotInferShapeFunctor); diff --git a/paddle/fluid/operators/one_hot_v2_op.cu b/paddle/fluid/operators/one_hot_v2_op.cu deleted file mode 100644 index 77e2a931e50de5b7775463fc7bbf6262e2ad4a53..0000000000000000000000000000000000000000 --- a/paddle/fluid/operators/one_hot_v2_op.cu +++ /dev/null @@ -1,100 +0,0 @@ -// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "paddle/fluid/operators/one_hot_v2_op.h" -#include "paddle/fluid/platform/device/gpu/gpu_info.h" -#include "paddle/fluid/platform/device/gpu/gpu_primitives.h" - -namespace paddle { -namespace operators { -using platform::PADDLE_CUDA_NUM_THREADS; - -template -__global__ void FillOutputKernel(const InT* p_in_data, OutT* p_out_data, - const int64_t numel, const int depth) { - int idx = blockIdx.x * blockDim.x + threadIdx.x; - if (idx < numel && p_in_data[idx] >= 0 && p_in_data[idx] < depth) { - *(p_out_data + (idx * depth) + p_in_data[idx]) = 1.0; - } -} - -template -struct OneHotV2OpCUDAFunctor { - const framework::LoDTensor* in_; - framework::LoDTensor* out_; - const DeviceContext& ctx_; - int depth_; - - OneHotV2OpCUDAFunctor(const framework::LoDTensor* in, - framework::LoDTensor* out, int depth, - const DeviceContext& ctx) - : in_(in), out_(out), depth_(depth), ctx_(ctx) {} - - template - void apply() const { - auto* p_in_data = in_->data(); - auto numel = in_->numel(); - auto* p_out_data = out_->mutable_data(ctx_.GetPlace()); - auto stream = ctx_.stream(); - phi::funcs::set_constant(ctx_, out_, 0.0); - - FillOutputKernel<<<(numel + PADDLE_CUDA_NUM_THREADS - 1) / - PADDLE_CUDA_NUM_THREADS, - PADDLE_CUDA_NUM_THREADS, 0, stream>>>( - p_in_data, p_out_data, numel, depth_); - } -}; - -using LoDTensor = framework::LoDTensor; -template -class OneHotV2CUDAKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& context) const override { - auto* in = context.Input("X"); - auto* out = context.Output("Out"); - - int depth = -1; - if (context.HasInput("depth_tensor")) { - auto* depth_tensor = context.Input("depth_tensor"); - if (platform::is_gpu_place(depth_tensor->place())) { - framework::Tensor temp; - paddle::framework::TensorCopySync(*depth_tensor, platform::CPUPlace(), - &temp); - depth = *temp.data(); - } else { - depth = *depth_tensor->data(); - } - - auto out_dims = out->dims(); - out_dims[out_dims.size() - 1] = depth; - out->Resize(out_dims); - } else { - depth = context.Attr("depth"); - } - framework::VisitDataType( - static_cast( - context.Attr("dtype")), - OneHotV2OpCUDAFunctor( - in, out, depth, context.template device_context())); - } -}; - -} // namespace operators -} // namespace paddle - -namespace ops = paddle::operators; -REGISTER_OP_CUDA_KERNEL( - one_hot_v2, - ops::OneHotV2CUDAKernel, - ops::OneHotV2CUDAKernel); diff --git a/paddle/fluid/operators/one_hot_v2_op_npu.cc b/paddle/fluid/operators/one_hot_v2_op_npu.cc index acf6baf50b418ae0fd68d64f52f80f47df1c60c3..e5702a37bb2b4a4180e209bb5e306be64830bd99 100644 --- a/paddle/fluid/operators/one_hot_v2_op_npu.cc +++ b/paddle/fluid/operators/one_hot_v2_op_npu.cc @@ -12,13 +12,14 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "paddle/fluid/operators/one_hot_v2_op.h" +#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/platform/device/npu/npu_op_runner.h" namespace paddle { namespace operators { using Tensor = framework::Tensor; +using LoDTensor = framework::LoDTensor; template class OneHotV2NPUKernel : public framework::OpKernel { diff --git a/paddle/phi/core/compat/op_utils.h b/paddle/phi/core/compat/op_utils.h index 00e9bff9bd5910ceedcca3dfb3a7a64ec88596df..7f4384545f353ecdbd33c73751e186061bf316cc 100644 --- a/paddle/phi/core/compat/op_utils.h +++ b/paddle/phi/core/compat/op_utils.h @@ -55,6 +55,7 @@ const std::unordered_set deprecated_op_names({"diag", "expand_grad", "expand_as_grad", "sum", + "one_hot", "sum_grad", "top_k", "top_k_grad"}); diff --git a/paddle/phi/core/meta_tensor.cc b/paddle/phi/core/meta_tensor.cc index 38a6e09a61ef83aa313a67a5de1ee21ce16038eb..bcbb1a4835b9d0397f6e85b7c44311bb9fe57209 100644 --- a/paddle/phi/core/meta_tensor.cc +++ b/paddle/phi/core/meta_tensor.cc @@ -72,6 +72,10 @@ void MetaTensor::set_layout(DataLayout layout) { } void MetaTensor::share_lod(const MetaTensor& meta_tensor) { + if (meta_tensor.lod().size() == 0) { + // no need share + return; + } if (phi::DenseTensor::classof(tensor_)) { DenseTensorUtils::GetMutableMeta(static_cast(tensor_))->lod = meta_tensor.lod(); diff --git a/paddle/phi/infermeta/unary.cc b/paddle/phi/infermeta/unary.cc index d09a2191fb2d664b15de4904d3abe30b1091286b..4d1cb42bd59f072e5926b237528a742c231bcdcf 100644 --- a/paddle/phi/infermeta/unary.cc +++ b/paddle/phi/infermeta/unary.cc @@ -1602,6 +1602,43 @@ void UnfoldInferMeta(const MetaTensor& x, out->set_dims(phi::make_ddim(out_dims)); } +void OneHotRawInferMeta(const MetaTensor& x, + int32_t depth, + DataType dtype, + bool allow_out_of_range, + MetaTensor* out) { + auto x_dims = x.dims(); + PADDLE_ENFORCE_GE( + x_dims.size(), + 1, + phi::errors::InvalidArgument("Rank of Input(X) should be at least 1.")); + + auto out_dims_vec = phi::vectorize(x_dims); + out_dims_vec.push_back(depth); + auto out_dims = phi::make_ddim(out_dims_vec); + out->set_dims(out_dims); + out->share_lod(x); + out->set_dtype(dtype); +} + +void OneHotInferMeta(const MetaTensor& x, + const Scalar& depth_t, + MetaTensor* out) { + auto x_dims = x.dims(); + PADDLE_ENFORCE_GE( + x_dims.size(), + 1, + phi::errors::InvalidArgument("Rank of Input(X) should be at least 1.")); + + int depth = depth_t.to(); + auto out_dims_vec = phi::vectorize(x_dims); + out_dims_vec.push_back(depth); + auto out_dims = phi::make_ddim(out_dims_vec); + out->set_dims(out_dims); + out->share_lod(x); + out->set_dtype(phi::DataType::FLOAT32); +} + void WhereIndexInferMeta(const MetaTensor& condition, MetaTensor* out) { auto rank = condition.dims().size(); PADDLE_ENFORCE_GE( diff --git a/paddle/phi/infermeta/unary.h b/paddle/phi/infermeta/unary.h index a1fc6fd4053d7c27af37811c28540553ea5c1d7c..75fb9fadf82dc87ac18814e0674e5012fea95ec4 100644 --- a/paddle/phi/infermeta/unary.h +++ b/paddle/phi/infermeta/unary.h @@ -228,6 +228,14 @@ void UnfoldInferMeta(const MetaTensor& x, MetaTensor* out, MetaConfig config = MetaConfig()); +void OneHotRawInferMeta(const MetaTensor& x, + int32_t depth, + DataType dtype, + bool allow_out_of_range, + MetaTensor* out); + +void OneHotInferMeta(const MetaTensor& x, const Scalar& depth, MetaTensor* out); + void WhereIndexInferMeta(const MetaTensor& condition, MetaTensor* out); } // namespace phi diff --git a/paddle/fluid/operators/one_hot_v2_op.h b/paddle/phi/kernels/cpu/one_hot_kernel.cc similarity index 50% rename from paddle/fluid/operators/one_hot_v2_op.h rename to paddle/phi/kernels/cpu/one_hot_kernel.cc index 9d42c5875bb6eecd1244ca1a0dd6442985ec2a02..dc58489ebf70eaaa7efba52775f7ac62bb2ef5b2 100644 --- a/paddle/fluid/operators/one_hot_v2_op.h +++ b/paddle/phi/kernels/cpu/one_hot_kernel.cc @@ -1,4 +1,4 @@ -// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -12,23 +12,25 @@ // See the License for the specific language governing permissions and // limitations under the License. -#pragma once -#include "paddle/fluid/framework/op_registry.h" +#include "paddle/phi/kernels/one_hot_kernel.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/core/utils/data_type.h" #include "paddle/phi/kernels/funcs/math_function.h" -namespace paddle { -namespace operators { +namespace phi { template struct OneHotV2OpFunctor { - const framework::LoDTensor* in_; - framework::LoDTensor* out_; + const DenseTensor* in_; + DenseTensor* out_; int depth_; const DeviceContext& ctx_; bool allow_out_of_range_; - OneHotV2OpFunctor(const framework::LoDTensor* in, framework::LoDTensor* out, - int depth, const DeviceContext& ctx, + OneHotV2OpFunctor(const DenseTensor* in, + DenseTensor* out, + int depth, + const DeviceContext& ctx, bool allow_out_of_range = false) : in_(in), out_(out), @@ -40,8 +42,8 @@ struct OneHotV2OpFunctor { void apply() const { auto* p_in_data = in_->data(); auto numel = in_->numel(); - auto* p_out_data = out_->mutable_data(ctx_.GetPlace()); - phi::funcs::set_constant(ctx_, out_, 0.0); + auto* p_out_data = ctx_.template Alloc(out_); + funcs::set_constant(ctx_, out_, 0.0); if (allow_out_of_range_) { for (int i = 0; i < numel; ++i) { @@ -52,51 +54,46 @@ struct OneHotV2OpFunctor { } else { for (int i = 0; i < numel; ++i) { PADDLE_ENFORCE_GE( - p_in_data[i], 0, - platform::errors::InvalidArgument( + p_in_data[i], + 0, + phi::errors::InvalidArgument( "Illegal index value, Input(input) value should be at least 0, " "but received input (%d) less than 0", p_in_data[i])); PADDLE_ENFORCE_LT( - p_in_data[i], depth_, - platform::errors::InvalidArgument( + p_in_data[i], + depth_, + phi::errors::InvalidArgument( "Illegal index value, Input(input) value should be less than " "Input(depth), " "but received input (%d) not less than depth (%d)", - p_in_data[i], depth_)); + p_in_data[i], + depth_)); *(p_out_data + i * depth_ + p_in_data[i]) = 1.0; } } } }; -using LoDTensor = framework::LoDTensor; -using Tensor = framework::Tensor; -template -class OneHotV2Kernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& context) const override { - auto* in = context.Input("X"); - auto* out = context.Output("Out"); - int depth = context.Attr("depth"); - bool allow_out_of_range = context.Attr("allow_out_of_range"); - if (context.HasInput("depth_tensor")) { - auto* depth_tensor = context.Input("depth_tensor"); - auto* depth_data = depth_tensor->data(); - depth = depth_data[0]; - auto out_dims = out->dims(); - out_dims[out_dims.size() - 1] = depth; - out->Resize(out_dims); - } - - framework::VisitDataType( - static_cast( - context.Attr("dtype")), - OneHotV2OpFunctor( - in, out, depth, context.template device_context(), - allow_out_of_range)); +template +void OneHotRawKernel(const Context& dev_ctx, + const DenseTensor& x, + int32_t depth, + DataType dtype, + bool allow_out_of_range, + DenseTensor* out) { + auto out_dims = out->dims(); + if (out_dims[out_dims.size() - 1] == -1) { + out_dims[out_dims.size() - 1] = depth; + out->Resize(out_dims); } -}; -} // namespace operators -} // namespace paddle + phi::VisitDataType(dtype, + OneHotV2OpFunctor( + &x, out, depth, dev_ctx, allow_out_of_range)); +} + +} // namespace phi + +PD_REGISTER_KERNEL( + one_hot_raw, CPU, ALL_LAYOUT, phi::OneHotRawKernel, int, int64_t) {} diff --git a/paddle/phi/kernels/gpu/one_hot_kernel.cu b/paddle/phi/kernels/gpu/one_hot_kernel.cu new file mode 100644 index 0000000000000000000000000000000000000000..32c7fa1e85d150b99e7a05d169b01cd8727c1a98 --- /dev/null +++ b/paddle/phi/kernels/gpu/one_hot_kernel.cu @@ -0,0 +1,86 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/one_hot_kernel.h" + +#include "paddle/fluid/platform/device/gpu/gpu_info.h" +#include "paddle/fluid/platform/device/gpu/gpu_primitives.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/funcs/math_function.h" + +namespace phi { + +using paddle::platform::PADDLE_CUDA_NUM_THREADS; + +template +__global__ void FillOutputKernel(const InT* p_in_data, + OutT* p_out_data, + const int64_t numel, + const int depth) { + int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < numel && p_in_data[idx] >= 0 && p_in_data[idx] < depth) { + *(p_out_data + (idx * depth) + p_in_data[idx]) = 1.0; + } +} + +template +struct OneHotV2OpCUDAFunctor { + const DenseTensor* in_; + DenseTensor* out_; + const DeviceContext& ctx_; + int depth_; + + OneHotV2OpCUDAFunctor(const DenseTensor* in, + DenseTensor* out, + int depth, + const DeviceContext& ctx) + : in_(in), out_(out), depth_(depth), ctx_(ctx) {} + + template + void apply() const { + auto* p_in_data = in_->data(); + auto numel = in_->numel(); + auto* p_out_data = ctx_.template Alloc(out_); + auto stream = ctx_.stream(); + funcs::set_constant(ctx_, out_, 0.0); + + FillOutputKernel<<<(numel + PADDLE_CUDA_NUM_THREADS - 1) / + PADDLE_CUDA_NUM_THREADS, + PADDLE_CUDA_NUM_THREADS, + 0, + stream>>>(p_in_data, p_out_data, numel, depth_); + } +}; + +template +void OneHotRawKernel(const Context& dev_ctx, + const DenseTensor& x, + int32_t depth, + DataType dtype, + bool allow_out_of_range, + DenseTensor* out) { + auto out_dims = out->dims(); + if (out_dims[out_dims.size() - 1] == -1) { + out_dims[out_dims.size() - 1] = depth; + out->Resize(out_dims); + } + + phi::VisitDataType( + dtype, OneHotV2OpCUDAFunctor(&x, out, depth, dev_ctx)); +} + +} // namespace phi + +PD_REGISTER_KERNEL( + one_hot_raw, GPU, ALL_LAYOUT, phi::OneHotRawKernel, int, int64_t) {} diff --git a/paddle/phi/kernels/one_hot_kernel.cc b/paddle/phi/kernels/one_hot_kernel.cc new file mode 100644 index 0000000000000000000000000000000000000000..633f48cbb62ace9e3f7f21502bd61f8c305fb542 --- /dev/null +++ b/paddle/phi/kernels/one_hot_kernel.cc @@ -0,0 +1,38 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/one_hot_kernel.h" + +#include "paddle/phi/backends/all_context.h" +#include "paddle/phi/core/kernel_registry.h" + +namespace phi { + +template +void OneHotKernel(const Context& dev_ctx, + const DenseTensor& x, + const Scalar& num_classes_s, + DenseTensor* out) { + int num_classes = num_classes_s.to(); + OneHotRawKernel( + dev_ctx, x, num_classes, phi::DataType::FLOAT32, false, out); +} + +} // namespace phi + +PD_REGISTER_KERNEL(one_hot, CPU, ALL_LAYOUT, phi::OneHotKernel, int, int64_t) {} + +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) +PD_REGISTER_KERNEL(one_hot, GPU, ALL_LAYOUT, phi::OneHotKernel, int, int64_t) {} +#endif diff --git a/paddle/phi/kernels/one_hot_kernel.h b/paddle/phi/kernels/one_hot_kernel.h new file mode 100644 index 0000000000000000000000000000000000000000..9f89609ea63365b0e7831201ca003d6c7320c5d7 --- /dev/null +++ b/paddle/phi/kernels/one_hot_kernel.h @@ -0,0 +1,36 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/phi/common/scalar.h" +#include "paddle/phi/core/dense_tensor.h" + +namespace phi { + +template +void OneHotKernel(const Context& dev_ctx, + const DenseTensor& x, + const Scalar& num_classes, + DenseTensor* out); + +template +void OneHotRawKernel(const Context& dev_ctx, + const DenseTensor& x, + int32_t depth, + DataType dtype, + bool allow_out_of_range, + DenseTensor* out); + +} // namespace phi diff --git a/paddle/phi/ops/compat/one_hot_sig.cc b/paddle/phi/ops/compat/one_hot_sig.cc new file mode 100644 index 0000000000000000000000000000000000000000..655969093c889aa32ae780f1de3c9c7c81a78eb1 --- /dev/null +++ b/paddle/phi/ops/compat/one_hot_sig.cc @@ -0,0 +1,37 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/core/compat/op_utils.h" + +namespace phi { + +KernelSignature OneHotOpArgumentMapping(const ArgumentMappingContext& ctx) { + if (ctx.HasInput("depth_tensor")) { + return KernelSignature("one_hot_raw", + {"X"}, + {"depth_tensor", "dtype", "allow_out_of_range"}, + {"Out"}); + } else { + return KernelSignature("one_hot_raw", + {"X"}, + {"depth", "dtype", "allow_out_of_range"}, + {"Out"}); + } +} + +} // namespace phi + +PD_REGISTER_BASE_KERNEL_NAME(one_hot_v2, one_hot); + +PD_REGISTER_ARG_MAPPING_FN(one_hot_v2, phi::OneHotOpArgumentMapping); diff --git a/python/paddle/fluid/dygraph/tracer.py b/python/paddle/fluid/dygraph/tracer.py index a7dd938a1cfe71e0dbddd2ac689a38fe78db748b..d0552ca41f0daf56ce23317dd06cb5744baaff84 100644 --- a/python/paddle/fluid/dygraph/tracer.py +++ b/python/paddle/fluid/dygraph/tracer.py @@ -52,6 +52,12 @@ final_state_name_mapping = { "axis1": "axis1", "axis2": "axis2", "out": "Out", + }, + "one_hot": { + "final_op_name": "final_state_one_hot", + "x": "X", + "num_class": "depth", + "out": "Out", } } diff --git a/python/paddle/fluid/tests/unittests/test_one_hot_v2_op.py b/python/paddle/fluid/tests/unittests/test_one_hot_v2_op.py index 66de1b309797fb53316a46b436e5cccf11410216..fac258192112dbff5353c581ad8e276cc5e375c0 100644 --- a/python/paddle/fluid/tests/unittests/test_one_hot_v2_op.py +++ b/python/paddle/fluid/tests/unittests/test_one_hot_v2_op.py @@ -22,7 +22,8 @@ import paddle import paddle.fluid as fluid import paddle.fluid.core as core import paddle.fluid.framework as framework -from paddle.fluid.framework import Program, program_guard +from paddle.framework import _in_eager_mode +from paddle.fluid.framework import Program, program_guard, _test_eager_guard class TestOneHotOp(OpTest): @@ -45,7 +46,7 @@ class TestOneHotOp(OpTest): self.outputs = {'Out': (out, x_lod)} def test_check_output(self): - self.check_output(check_dygraph=False) + self.check_output() class TestOneHotOp_attr(OpTest): @@ -68,7 +69,7 @@ class TestOneHotOp_attr(OpTest): self.outputs = {'Out': (out, x_lod)} def test_check_output(self): - self.check_output(check_dygraph=False) + self.check_output() class TestOneHotOp_default_dtype(OpTest): @@ -91,7 +92,7 @@ class TestOneHotOp_default_dtype(OpTest): self.outputs = {'Out': (out, x_lod)} def test_check_output(self): - self.check_output(check_dygraph=False) + self.check_output() class TestOneHotOp_default_dtype_attr(OpTest): @@ -114,7 +115,7 @@ class TestOneHotOp_default_dtype_attr(OpTest): self.outputs = {'Out': (out, x_lod)} def test_check_output(self): - self.check_output(check_dygraph=False) + self.check_output() class TestOneHotOp_out_of_range(OpTest): @@ -132,7 +133,7 @@ class TestOneHotOp_out_of_range(OpTest): self.outputs = {'Out': (out, x_lod)} def test_check_output(self): - self.check_output(check_dygraph=False) + self.check_output() class TestOneHotOp_exception(unittest.TestCase): @@ -190,6 +191,12 @@ class TestOneHotOpApi(unittest.TestCase): one_hot_label = fluid.one_hot( input=fluid.dygraph.to_variable(label), depth=depth) + one_hot_label = paddle.nn.functional.one_hot( + fluid.dygraph.to_variable(label), depth) + with _test_eager_guard(): + one_hot_label = paddle.nn.functional.one_hot( + paddle.to_tensor(label), depth) + def _run(self, depth): label = fluid.layers.data(name="label", shape=[1], dtype="int64") one_hot_label = fluid.one_hot(input=label, depth=depth) diff --git a/python/paddle/nn/functional/input.py b/python/paddle/nn/functional/input.py index de8a7ff6d3c7b6cd87d6301f2cd0bb7af119a74d..4c30ed03735f26b6df77c6a8f5b32391972738e5 100644 --- a/python/paddle/nn/functional/input.py +++ b/python/paddle/nn/functional/input.py @@ -19,6 +19,7 @@ from ...fluid.layer_helper import LayerHelper from ...fluid.data_feeder import check_variable_and_dtype, check_dtype from paddle import _C_ops from paddle import in_dynamic_mode +from paddle.framework import _in_eager_mode __all__ = [] @@ -87,6 +88,8 @@ def one_hot(x, num_classes, name=None): """ if in_dynamic_mode(): + if _in_eager_mode(): + return _C_ops.final_state_one_hot(x, num_classes) return _C_ops.one_hot_v2(x, 'depth', num_classes, 'allow_out_of_range', False) else: diff --git a/python/paddle/utils/code_gen/api.yaml b/python/paddle/utils/code_gen/api.yaml index 639afeb4c86faa9a3e18870cff07da2eb535cceb..0d012685b738c63f3eadc0c1d8c44592fe875845 100644 --- a/python/paddle/utils/code_gen/api.yaml +++ b/python/paddle/utils/code_gen/api.yaml @@ -204,6 +204,15 @@ output : Tensor invoke : full_like(x, 0, dtype, place) + +- api : one_hot + args : (Tensor x, Scalar num_classes) + output : Tensor + infer_meta : + func : OneHotInferMeta + kernel : + func : one_hot + - api : digamma args : (Tensor x) output : Tensor