From 1c29196e8de08edc18dbfc6c77ebcd22e595e1fd Mon Sep 17 00:00:00 2001 From: 0x45f <23097963+0x45f@users.noreply.github.com> Date: Mon, 7 Mar 2022 16:43:29 +0800 Subject: [PATCH] [Phi]Move bincount OP to phi (#39947) * move bincount OP to phi * fix dtype * set_dtype by weights or x * fix conflicts --- paddle/fluid/operators/bincount_op.cc | 62 ++------ paddle/fluid/operators/bincount_op.cu | 162 --------------------- paddle/fluid/operators/bincount_op.h | 109 -------------- paddle/phi/infermeta/binary.cc | 50 +++++++ paddle/phi/infermeta/binary.h | 4 + paddle/phi/kernels/bincount_kernel.h | 28 ++++ paddle/phi/kernels/cpu/bincount_kernel.cc | 106 ++++++++++++++ paddle/phi/kernels/gpu/bincount_kernel.cu | 164 ++++++++++++++++++++++ paddle/phi/ops/compat/bincount_sig.cc | 25 ++++ 9 files changed, 386 insertions(+), 324 deletions(-) delete mode 100644 paddle/fluid/operators/bincount_op.cu delete mode 100644 paddle/fluid/operators/bincount_op.h create mode 100644 paddle/phi/kernels/bincount_kernel.h create mode 100644 paddle/phi/kernels/cpu/bincount_kernel.cc create mode 100644 paddle/phi/kernels/gpu/bincount_kernel.cu create mode 100644 paddle/phi/ops/compat/bincount_sig.cc diff --git a/paddle/fluid/operators/bincount_op.cc b/paddle/fluid/operators/bincount_op.cc index b37334a14ba..062e7d510d5 100644 --- a/paddle/fluid/operators/bincount_op.cc +++ b/paddle/fluid/operators/bincount_op.cc @@ -12,12 +12,15 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "paddle/fluid/operators/bincount_op.h" - #include #include #include +#include "paddle/fluid/framework/infershape_utils.h" +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/phi/core/infermeta_utils.h" +#include "paddle/phi/infermeta/binary.h" + namespace paddle { namespace operators { @@ -28,51 +31,6 @@ class BincountOp : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; - void InferShape(framework::InferShapeContext *ctx) const override { - PADDLE_ENFORCE_EQ(ctx->HasInput("X"), true, - platform::errors::InvalidArgument( - "Input(X) of BincountOp should not be null.")); - PADDLE_ENFORCE_EQ(ctx->HasOutput("Out"), true, - platform::errors::InvalidArgument( - "Output(Out) of BincountOp should not be null.")); - - auto input_dim = ctx->GetInputDim("X"); - auto minlength = ctx->Attrs().Get("minlength"); - - PADDLE_ENFORCE_GE(minlength, 0, - platform::errors::InvalidArgument( - "The minlength should be greater than or equal to 0." - "But received minlength is %d", - minlength)); - - PADDLE_ENFORCE_EQ(input_dim.size(), 1, - platform::errors::InvalidArgument( - "The 'shape' of Input(X) must be 1-D tensor." - "But the dimension of Input(X) is [%d]", - input_dim.size())); - - if (ctx->HasInput("Weights")) { - auto weights_dim = ctx->GetInputDim("Weights"); - PADDLE_ENFORCE_EQ(weights_dim.size(), 1, - platform::errors::InvalidArgument( - "The 'shape' of Input(Weights) must be 1-D tensor." - "But the dimension of Input(Weights) is [%d]", - weights_dim.size())); - - PADDLE_ENFORCE_EQ( - weights_dim[0], input_dim[0], - platform::errors::InvalidArgument( - "The 'shape' of Input(Weights) must be equal to the 'shape' of " - "Input(X)." - "But received: the 'shape' of Input(Weights) is [%s]," - "the 'shape' of Input(X) is [%s]", - weights_dim, input_dim)); - } - - ctx->SetOutputDim("Out", phi::make_ddim({-1})); - ctx->ShareLoD("X", /*->*/ "Out"); - } - framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext &ctx) const { auto data_type = @@ -105,12 +63,10 @@ class BincountOpMaker : public framework::OpProtoAndCheckerMaker { } // namespace paddle namespace ops = paddle::operators; +DECLARE_INFER_SHAPE_FUNCTOR(bincount, BincountInferShapeFunctor, + PD_INFER_META(phi::BincountInferMeta)); REGISTER_OPERATOR( bincount, ops::BincountOp, ops::BincountOpMaker, paddle::framework::EmptyGradOpMaker, - paddle::framework::EmptyGradOpMaker); -REGISTER_OP_CPU_KERNEL( - bincount, ops::BincountKernel, - ops::BincountKernel, - ops::BincountKernel, - ops::BincountKernel); + paddle::framework::EmptyGradOpMaker, + BincountInferShapeFunctor); diff --git a/paddle/fluid/operators/bincount_op.cu b/paddle/fluid/operators/bincount_op.cu deleted file mode 100644 index cc576d0af92..00000000000 --- a/paddle/fluid/operators/bincount_op.cu +++ /dev/null @@ -1,162 +0,0 @@ -/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#include "paddle/fluid/framework/eigen.h" -#include "paddle/fluid/operators/bincount_op.h" -#include "paddle/fluid/platform/device/gpu/gpu_launch_config.h" -#include "paddle/fluid/platform/device/gpu/gpu_primitives.h" -#include "paddle/phi/core/hostdevice.h" - -namespace paddle { -namespace operators { - -using Tensor = framework::Tensor; -using platform::PADDLE_CUDA_NUM_THREADS; - -inline int GET_BLOCKS(const int N) { - return (N + PADDLE_CUDA_NUM_THREADS - 1) / PADDLE_CUDA_NUM_THREADS; -} - -template -__global__ void KernelBincount(const InputT* input, const int total_elements, - const bool has_weights, const T* weights, - OutT* output) { - if (!has_weights) { - for (int i = threadIdx.x; i < total_elements; i += blockDim.x) { - paddle::platform::CudaAtomicAdd(&output[input[i]], 1L); - } - } else { - for (int i = threadIdx.x; i < total_elements; i += blockDim.x) { - paddle::platform::CudaAtomicAdd(&output[input[i]], - static_cast(weights[i])); - } - } -} - -template -void BincountCUDAInner(const framework::ExecutionContext& context) { - const Tensor* input = context.Input("X"); - const Tensor* weights = context.Input("Weights"); - Tensor* output = context.Output("Out"); - auto& minlength = context.Attr("minlength"); - - const InputT* input_data = input->data(); - - const int input_numel = input->numel(); - - if (input_data == nullptr) { - framework::DDim out_dim{0}; - output->Resize(out_dim); - output->mutable_data(context.GetPlace()); - return; - } - auto input_x = framework::EigenVector::Flatten(*input); - - framework::Tensor input_min_t, input_max_t; - auto* input_max_data = - input_max_t.mutable_data({1}, context.GetPlace()); - auto* input_min_data = - input_min_t.mutable_data({1}, context.GetPlace()); - - auto input_max_scala = framework::EigenScalar::From(input_max_t); - auto input_min_scala = framework::EigenScalar::From(input_min_t); - - auto* place = context.template device_context().eigen_device(); - input_max_scala.device(*place) = input_x.maximum(); - input_min_scala.device(*place) = input_x.minimum(); - - Tensor input_min_cpu, input_max_cpu; - paddle::framework::TensorCopySync(input_max_t, platform::CPUPlace(), - &input_max_cpu); - paddle::framework::TensorCopySync(input_min_t, platform::CPUPlace(), - &input_min_cpu); - - InputT input_min = input_min_cpu.data()[0]; - - PADDLE_ENFORCE_GE( - input_min, static_cast(0), - platform::errors::InvalidArgument( - "The elements in input tensor must be non-negative ints")); - - int64_t output_size = - static_cast(input_max_cpu.data()[0]) + 1L; - - output_size = std::max(output_size, static_cast(minlength)); - framework::DDim out_dim{output_size}; - output->Resize(out_dim); - - bool has_weights = (weights != nullptr); - - const T* weights_data = has_weights ? weights->data() : nullptr; - - auto stream = - context.template device_context().stream(); - - if (!has_weights) { - int64_t* output_data = output->mutable_data(context.GetPlace()); - phi::funcs::SetConstant()( - context.template device_context(), output, 0L); - - KernelBincount<<>>( - input_data, input_numel, has_weights, weights_data, output_data); - } else { - const auto& weights_type = framework::TransToProtoVarType(weights->dtype()); - - if (weights_type == framework::proto::VarType::FP32) { - float* output_data = output->mutable_data(context.GetPlace()); - phi::funcs::SetConstant()( - context.template device_context(), output, - static_cast(0)); - - KernelBincount<<>>( - input_data, input_numel, has_weights, weights_data, output_data); - } else { - double* output_data = output->mutable_data(context.GetPlace()); - phi::funcs::SetConstant()( - context.template device_context(), output, - static_cast(0)); - - KernelBincount<<>>( - input_data, input_numel, has_weights, weights_data, output_data); - } - } -} - -template -class BincountCUDAKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& context) const override { - const Tensor* input = context.Input("X"); - const auto& input_type = framework::TransToProtoVarType(input->dtype()); - - if (input_type == framework::proto::VarType::INT32) { - BincountCUDAInner(context); - } else if (input_type == framework::proto::VarType::INT64) { - BincountCUDAInner(context); - } - } -}; - -} // namespace operators -} // namespace paddle - -namespace ops = paddle::operators; -REGISTER_OP_CUDA_KERNEL( - bincount, ops::BincountCUDAKernel, - ops::BincountCUDAKernel, - ops::BincountCUDAKernel, - ops::BincountCUDAKernel); diff --git a/paddle/fluid/operators/bincount_op.h b/paddle/fluid/operators/bincount_op.h deleted file mode 100644 index 84256bf78e4..00000000000 --- a/paddle/fluid/operators/bincount_op.h +++ /dev/null @@ -1,109 +0,0 @@ -/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#pragma once - -#include - -#include "paddle/fluid/framework/op_registry.h" -#include "paddle/fluid/framework/operator.h" -#include "paddle/phi/kernels/funcs/math_function.h" - -namespace paddle { -namespace operators { - -using Tensor = framework::Tensor; - -template -void BincountInner(const framework::ExecutionContext& context) { - const Tensor* input = context.Input("X"); - const Tensor* weights = context.Input("Weights"); - Tensor* output = context.Output("Out"); - auto& minlength = context.Attr("minlength"); - - const InputT* input_data = input->data(); - - auto input_numel = input->numel(); - - if (input_data == nullptr) { - framework::DDim out_dim{0}; - output->Resize(out_dim); - output->mutable_data(context.GetPlace()); - return; - } - - PADDLE_ENFORCE_GE( - *std::min_element(input_data, input_data + input_numel), - static_cast(0), - platform::errors::InvalidArgument( - "The elements in input tensor must be non-negative ints")); - - int64_t output_size = static_cast(*std::max_element( - input_data, input_data + input_numel)) + - 1L; - output_size = std::max(output_size, static_cast(minlength)); - - framework::DDim out_dim{output_size}; - output->Resize(out_dim); - - bool has_weights = (weights != nullptr); - - if (has_weights) { - const T* weights_data = weights->data(); - const auto& weights_type = framework::TransToProtoVarType(weights->dtype()); - if (weights_type == framework::proto::VarType::FP32) { - float* output_data = output->mutable_data(context.GetPlace()); - phi::funcs::SetConstant()( - context.template device_context(), output, - static_cast(0)); - for (int64_t i = 0; i < input_numel; i++) { - output_data[input_data[i]] += static_cast(weights_data[i]); - } - } else { - double* output_data = output->mutable_data(context.GetPlace()); - phi::funcs::SetConstant()( - context.template device_context(), output, - static_cast(0)); - for (int64_t i = 0; i < input_numel; i++) { - output_data[input_data[i]] += static_cast(weights_data[i]); - } - } - - } else { - int64_t* output_data = output->mutable_data(context.GetPlace()); - phi::funcs::SetConstant()( - context.template device_context(), output, 0L); - for (int64_t i = 0; i < input_numel; i++) { - output_data[input_data[i]] += 1L; - } - } -} - -template -class BincountKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& context) const override { - const Tensor* input = context.Input("X"); - const auto& input_type = framework::TransToProtoVarType(input->dtype()); - - if (input_type == framework::proto::VarType::INT32) { - BincountInner(context); - } else if (input_type == framework::proto::VarType::INT64) { - BincountInner(context); - } - } -}; - -} // namespace operators -} // namespace paddle diff --git a/paddle/phi/infermeta/binary.cc b/paddle/phi/infermeta/binary.cc index 94b489906c6..55230aa8d05 100644 --- a/paddle/phi/infermeta/binary.cc +++ b/paddle/phi/infermeta/binary.cc @@ -456,6 +456,56 @@ void BCELossInferMeta(const MetaTensor& input, out->share_lod(input); } +void BincountInferMeta(const MetaTensor& x, + const paddle::optional weights, + int minlength, + MetaTensor* out) { + auto input_dim = x.dims(); + + PADDLE_ENFORCE_GE(minlength, + 0, + phi::errors::InvalidArgument( + "The minlength should be greater than or equal to 0." + "But received minlength is %d", + minlength)); + + PADDLE_ENFORCE_EQ( + input_dim.size(), + 1, + phi::errors::InvalidArgument("The 'shape' of Input(X) must be 1-D tensor." + "But the dimension of Input(X) is [%d]", + input_dim.size())); + + if (weights.is_initialized()) { + auto weights_dim = weights->dims(); + PADDLE_ENFORCE_EQ(weights_dim.size(), + 1, + phi::errors::InvalidArgument( + "The 'shape' of Input(Weights) must be 1-D tensor." + "But the dimension of Input(Weights) is [%d]", + weights_dim.size())); + + PADDLE_ENFORCE_EQ( + weights_dim[0], + input_dim[0], + phi::errors::InvalidArgument( + "The 'shape' of Input(Weights) must be equal to the 'shape' of " + "Input(X)." + "But received: the 'shape' of Input(Weights) is [%s]," + "the 'shape' of Input(X) is [%s]", + weights_dim, + input_dim)); + } + out->set_dims(phi::make_ddim({-1})); + if (weights.is_initialized()) { + out->set_dtype(weights->dtype()); + } else { + out->set_dtype(x.dtype()); + } + + out->share_lod(x); +} + void DistInferMeta(const MetaTensor& x, const MetaTensor& y, float p, diff --git a/paddle/phi/infermeta/binary.h b/paddle/phi/infermeta/binary.h index caf9185c900..106c22f7548 100644 --- a/paddle/phi/infermeta/binary.h +++ b/paddle/phi/infermeta/binary.h @@ -85,6 +85,10 @@ void BCELossInferMeta(const MetaTensor& input, MetaTensor* out, MetaConfig config = MetaConfig()); +void BincountInferMeta(const MetaTensor& x, + const paddle::optional weights, + int minlength, + MetaTensor* out); void DistInferMeta(const MetaTensor& x, const MetaTensor& y, float p, diff --git a/paddle/phi/kernels/bincount_kernel.h b/paddle/phi/kernels/bincount_kernel.h new file mode 100644 index 00000000000..3ba69d36548 --- /dev/null +++ b/paddle/phi/kernels/bincount_kernel.h @@ -0,0 +1,28 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/phi/core/dense_tensor.h" + +namespace phi { + +template +void BincountKernel(const Context& dev_ctx, + const DenseTensor& x, + const paddle::optional weights, + int minlength, + DenseTensor* out); + +} // namespace phi diff --git a/paddle/phi/kernels/cpu/bincount_kernel.cc b/paddle/phi/kernels/cpu/bincount_kernel.cc new file mode 100644 index 00000000000..c9dc44c1e04 --- /dev/null +++ b/paddle/phi/kernels/cpu/bincount_kernel.cc @@ -0,0 +1,106 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/bincount_kernel.h" + +#include "paddle/phi/backends/cpu/cpu_context.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/funcs/math_function.h" + +namespace phi { + +template +void BincountInner(const Context& dev_ctx, + const DenseTensor& x, + const paddle::optional weights, + int minlength, + DenseTensor* out) { + const DenseTensor* input = &x; + DenseTensor* output = out; + const InputT* input_data = input->data(); + + auto input_numel = input->numel(); + + if (input_data == nullptr) { + phi::DDim out_dim{0}; + output->Resize(out_dim); + dev_ctx.template Alloc(output); + return; + } + + PADDLE_ENFORCE_GE( + *std::min_element(input_data, input_data + input_numel), + static_cast(0), + phi::errors::InvalidArgument( + "The elements in input tensor must be non-negative ints")); + + int64_t output_size = static_cast(*std::max_element( + input_data, input_data + input_numel)) + + 1L; + output_size = std::max(output_size, static_cast(minlength)); + + phi::DDim out_dim{output_size}; + output->Resize(out_dim); + + bool has_weights = weights.is_initialized(); + + if (has_weights) { + const T* weights_data = weights->data(); + if (weights->dtype() == DataType::FLOAT32) { + float* output_data = dev_ctx.template Alloc(output); + phi::funcs::SetConstant()( + dev_ctx, output, static_cast(0)); + for (int64_t i = 0; i < input_numel; i++) { + output_data[input_data[i]] += static_cast(weights_data[i]); + } + } else { + double* output_data = dev_ctx.template Alloc(output); + phi::funcs::SetConstant()( + dev_ctx, output, static_cast(0)); + for (int64_t i = 0; i < input_numel; i++) { + output_data[input_data[i]] += static_cast(weights_data[i]); + } + } + + } else { + int64_t* output_data = dev_ctx.template Alloc(output); + phi::funcs::SetConstant()(dev_ctx, output, 0L); + for (int64_t i = 0; i < input_numel; i++) { + output_data[input_data[i]] += 1L; + } + } +} + +template +void BincountKernel(const Context& dev_ctx, + const DenseTensor& x, + const paddle::optional weights, + int minlength, + DenseTensor* out) { + if (x.dtype() == DataType::INT32) { + BincountInner(dev_ctx, x, weights, minlength, out); + } else if (x.dtype() == DataType::INT64) { + BincountInner(dev_ctx, x, weights, minlength, out); + } +} +} // namespace phi + +PD_REGISTER_KERNEL(bincount, + CPU, + ALL_LAYOUT, + phi::BincountKernel, + float, + double, + int, + int64_t) {} diff --git a/paddle/phi/kernels/gpu/bincount_kernel.cu b/paddle/phi/kernels/gpu/bincount_kernel.cu new file mode 100644 index 00000000000..a4ec894790c --- /dev/null +++ b/paddle/phi/kernels/gpu/bincount_kernel.cu @@ -0,0 +1,164 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/bincount_kernel.h" + +#include "paddle/fluid/platform/device/gpu/gpu_primitives.h" +#include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/funcs/eigen/common.h" +#include "paddle/phi/kernels/funcs/math_function.h" + +namespace phi { + +using paddle::platform::PADDLE_CUDA_NUM_THREADS; + +inline int GET_BLOCKS(const int N) { + return (N + PADDLE_CUDA_NUM_THREADS - 1) / PADDLE_CUDA_NUM_THREADS; +} + +template +__global__ void KernelBincount(const InputT* input, + const int total_elements, + const bool has_weights, + const T* weights, + OutT* output) { + if (!has_weights) { + for (int i = threadIdx.x; i < total_elements; i += blockDim.x) { + paddle::platform::CudaAtomicAdd(&output[input[i]], 1L); + } + } else { + for (int i = threadIdx.x; i < total_elements; i += blockDim.x) { + paddle::platform::CudaAtomicAdd(&output[input[i]], + static_cast(weights[i])); + } + } +} + +template +void BincountCUDAInner(const Context& dev_ctx, + const DenseTensor& x, + const paddle::optional weights, + int minlength, + DenseTensor* out) { + const DenseTensor* input = &x; + DenseTensor* output = out; + const InputT* input_data = input->data(); + + const int input_numel = input->numel(); + + if (input_data == nullptr) { + phi::DDim out_dim{0}; + output->Resize(out_dim); + dev_ctx.template Alloc(output); + return; + } + auto input_x = EigenVector::Flatten(*input); + DenseTensor input_min_t, input_max_t; + input_max_t.Resize({1}); + auto* input_max_data = dev_ctx.template Alloc(&input_max_t); + input_min_t.Resize({1}); + auto* input_min_data = dev_ctx.template Alloc(&input_min_t); + + auto input_max_scala = EigenScalar::From(input_max_t); + auto input_min_scala = EigenScalar::From(input_min_t); + + auto* place = dev_ctx.eigen_device(); + input_max_scala.device(*place) = input_x.maximum(); + input_min_scala.device(*place) = input_x.minimum(); + + DenseTensor input_min_cpu, input_max_cpu; + paddle::framework::TensorCopySync( + input_max_t, phi::CPUPlace(), &input_max_cpu); + paddle::framework::TensorCopySync( + input_min_t, phi::CPUPlace(), &input_min_cpu); + + InputT input_min = input_min_cpu.data()[0]; + + PADDLE_ENFORCE_GE( + input_min, + static_cast(0), + phi::errors::InvalidArgument( + "The elements in input tensor must be non-negative ints")); + + int64_t output_size = + static_cast(input_max_cpu.data()[0]) + 1L; + + output_size = std::max(output_size, static_cast(minlength)); + phi::DDim out_dim{output_size}; + output->Resize(out_dim); + + bool has_weights = weights.is_initialized(); + + const T* weights_data = has_weights ? weights->data() : nullptr; + auto stream = dev_ctx.stream(); + + if (!has_weights) { + int64_t* output_data = dev_ctx.template Alloc(output); + phi::funcs::SetConstant()(dev_ctx, output, 0L); + + KernelBincount<<>>( + input_data, input_numel, has_weights, weights_data, output_data); + } else { + const auto& weights_type = + paddle::framework::TransToProtoVarType(weights->dtype()); + + if (weights->dtype() == DataType::FLOAT32) { + float* output_data = dev_ctx.template Alloc(output); + phi::funcs::SetConstant()( + dev_ctx, output, static_cast(0)); + + KernelBincount<<>>( + input_data, input_numel, has_weights, weights_data, output_data); + } else { + double* output_data = dev_ctx.template Alloc(output); + phi::funcs::SetConstant()( + dev_ctx, output, static_cast(0)); + KernelBincount<<>>( + input_data, input_numel, has_weights, weights_data, output_data); + } + } +} + +template +void BincountKernel(const Context& dev_ctx, + const DenseTensor& x, + const paddle::optional weights, + int minlength, + DenseTensor* out) { + if (x.dtype() == DataType::INT32) { + BincountCUDAInner(dev_ctx, x, weights, minlength, out); + } else if (x.dtype() == DataType::INT64) { + BincountCUDAInner(dev_ctx, x, weights, minlength, out); + } +} +} // namespace phi + +PD_REGISTER_KERNEL(bincount, + GPU, + ALL_LAYOUT, + phi::BincountKernel, + float, + double, + int, + int64_t) {} diff --git a/paddle/phi/ops/compat/bincount_sig.cc b/paddle/phi/ops/compat/bincount_sig.cc new file mode 100644 index 00000000000..35067c256ed --- /dev/null +++ b/paddle/phi/ops/compat/bincount_sig.cc @@ -0,0 +1,25 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/core/compat/op_utils.h" + +namespace phi { + +KernelSignature BincountOpArgumentMapping(const ArgumentMappingContext& ctx) { + return KernelSignature("bincount", {"X", "Weights"}, {"minlength"}, {"Out"}); +} + +} // namespace phi + +PD_REGISTER_ARG_MAPPING_FN(bincount, phi::BincountOpArgumentMapping); -- GitLab