diff --git a/.gitignore b/.gitignore
index 5018bf56c1633237b98d29a66eb86aed41fa6891..ce0cd3bc27b6225a8e6e24a8331022e6224603ac 100644
--- a/.gitignore
+++ b/.gitignore
@@ -51,3 +51,5 @@ paddle/infrt/dialect/pd_ops_info.h
 .lit_test_times.txt
 paddle/infrt/tests/dialect/Output
 paddle/infrt/tests/lit.cfg.py
+paddle/fluid/pybind/eager_final_state_op_function_impl.h
+paddle/fluid/pybind/tmp_eager_final_state_op_function_impl.h
diff --git a/paddle/fluid/framework/operator.cc b/paddle/fluid/framework/operator.cc
index 4670f043102d917f770b6fa5ca661a860941df33..b6d8ca4aa67cbfc3dd34c0a4ef68d2c1bdb7ed94 100644
--- a/paddle/fluid/framework/operator.cc
+++ b/paddle/fluid/framework/operator.cc
@@ -2165,6 +2165,8 @@ void OperatorWithKernel::BuildPtenKernelContext(
         pt_kernel_context->EmplaceBackAttr(BOOST_GET_CONST(float, attr));
       } else if (attr_defs[i].type_index == std::type_index(typeid(bool))) {
         pt_kernel_context->EmplaceBackAttr(BOOST_GET_CONST(bool, attr));
+      } else if (attr_defs[i].type_index == std::type_index(typeid(int64_t))) {
+        pt_kernel_context->EmplaceBackAttr(BOOST_GET_CONST(int64_t, attr));
       } else if (attr_defs[i].type_index ==
                  std::type_index(typeid(std::string))) {
         pt_kernel_context->EmplaceBackAttr(BOOST_GET_CONST(std::string, attr));
diff --git a/paddle/fluid/imperative/prepared_operator.h b/paddle/fluid/imperative/prepared_operator.h
index d5dc53196dd7f1abe854785e0e5c1ccd363d1c3f..465fc2fca138ef06f057c69eae2a3419136c1e72 100644
--- a/paddle/fluid/imperative/prepared_operator.h
+++ b/paddle/fluid/imperative/prepared_operator.h
@@ -421,6 +421,8 @@ void BuildDygraphPtenKernelContext(
         kernel_ctx->EmplaceBackAttr(BOOST_GET_CONST(float, attr));
       } else if (attr_defs[i].type_index == std::type_index(typeid(bool))) {
         kernel_ctx->EmplaceBackAttr(BOOST_GET_CONST(bool, attr));
+      } else if (attr_defs[i].type_index == std::type_index(typeid(int64_t))) {
+        kernel_ctx->EmplaceBackAttr(BOOST_GET_CONST(int64_t, attr));
       } else if (attr_defs[i].type_index ==
                  std::type_index(typeid(std::string))) {
         kernel_ctx->EmplaceBackAttr(BOOST_GET_CONST(std::string, attr));
diff --git a/paddle/fluid/operators/histogram_op.cc b/paddle/fluid/operators/histogram_op.cc
index 32cc38ef1953364266181598f44ccd54e9dc631c..2df6b539ff68aa4934dc2562792a55a58b670417 100644
--- a/paddle/fluid/operators/histogram_op.cc
+++ b/paddle/fluid/operators/histogram_op.cc
@@ -12,12 +12,12 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
 
-#include "paddle/fluid/operators/histogram_op.h"
-
 #include <string>
 #include <unordered_map>
 #include <vector>
 
+#include "paddle/fluid/framework/op_registry.h"
+
 namespace paddle {
 namespace operators {
 
@@ -85,8 +85,3 @@ REGISTER_OPERATOR(
     histogram, ops::HistogramOp, ops::HistogramOpMaker,
     paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
     paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);
-REGISTER_OP_CPU_KERNEL(
-    histogram, ops::HistogramKernel<paddle::platform::CPUDeviceContext, float>,
-    ops::HistogramKernel<paddle::platform::CPUDeviceContext, double>,
-    ops::HistogramKernel<paddle::platform::CPUDeviceContext, int>,
-    ops::HistogramKernel<paddle::platform::CPUDeviceContext, int64_t>);
diff --git a/paddle/fluid/operators/histogram_op.cu b/paddle/fluid/operators/histogram_op.cu
deleted file mode 100644
index 48a637e6c37b1cf37e5653397ded01775eb54551..0000000000000000000000000000000000000000
--- a/paddle/fluid/operators/histogram_op.cu
+++ /dev/null
@@ -1,156 +0,0 @@
-/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License. */
-
-#include "paddle/fluid/framework/eigen.h"
-#include "paddle/fluid/operators/histogram_op.h"
-#include "paddle/fluid/platform/device/gpu/gpu_launch_config.h"
-#include "paddle/fluid/platform/device/gpu/gpu_primitives.h"
-#include "paddle/pten/core/hostdevice.h"
-
-namespace paddle {
-namespace operators {
-
-using IndexType = int64_t;
-using Tensor = framework::Tensor;
-using platform::PADDLE_CUDA_NUM_THREADS;
-
-inline int GET_BLOCKS(const int N) {
-  return (N + PADDLE_CUDA_NUM_THREADS - 1) / PADDLE_CUDA_NUM_THREADS;
-}
-
-template <typename T, typename IndexType>
-__device__ static IndexType GetBin(T input_value, T min_value, T max_value,
-                                   int64_t nbins) {
-  IndexType bin = static_cast<IndexType>((input_value - min_value) * nbins /
-                                         (max_value - min_value));
-  IndexType output_index = bin < nbins - 1 ? bin : nbins - 1;
-  return output_index;
-}
-
-template <typename T, typename IndexType>
-__global__ void KernelHistogram(const T* input, const int total_elements,
-                                const int64_t nbins, const T min_value,
-                                const T max_value, int64_t* output) {
-  extern __shared__ int64_t buf_hist[];
-  for (int i = threadIdx.x; i < nbins; i += blockDim.x) {
-    buf_hist[i] = 0;
-  }
-  __syncthreads();
-
-  CUDA_KERNEL_LOOP(input_index, total_elements) {
-    // const IndexType input_index = threadIdx.x + blockIdx.x * blockDim.x;
-    const auto input_value = input[input_index];
-    if (input_value >= min_value && input_value <= max_value) {
-      const IndexType output_index =
-          GetBin<T, IndexType>(input_value, min_value, max_value, nbins);
-      paddle::platform::CudaAtomicAdd(&buf_hist[output_index], 1);
-    }
-  }
-  __syncthreads();
-
-  for (int i = threadIdx.x; i < nbins; i += blockDim.x) {
-    paddle::platform::CudaAtomicAdd(&output[i], buf_hist[i]);
-  }
-}
-
-template <typename DeviceContext, typename T>
-class HistogramCUDAKernel : public framework::OpKernel<T> {
- public:
-  void Compute(const framework::ExecutionContext& context) const override {
-    PADDLE_ENFORCE_EQ(
-        platform::is_gpu_place(context.GetPlace()), true,
-        platform::errors::InvalidArgument("It must use CUDAPlace."));
-
-    const Tensor* input = context.Input<Tensor>("X");
-    Tensor* output = context.Output<Tensor>("Out");
-    auto& nbins = context.Attr<int64_t>("bins");
-    auto& minval = context.Attr<int>("min");
-    auto& maxval = context.Attr<int>("max");
-
-    const T* input_data = input->data<T>();
-    const int input_numel = input->numel();
-
-    int64_t* out_data = output->mutable_data<int64_t>(context.GetPlace());
-    pten::funcs::SetConstant<platform::CUDADeviceContext, int64_t>()(
-        context.template device_context<platform::CUDADeviceContext>(), output,
-        static_cast<int64_t>(0));
-
-    if (input_data == nullptr) return;
-
-    T output_min = static_cast<T>(minval);
-    T output_max = static_cast<T>(maxval);
-
-    if (output_min == output_max) {
-      auto input_x = framework::EigenVector<T>::Flatten(*input);
-
-      framework::Tensor input_min_t, input_max_t;
-      auto* input_min_data =
-          input_min_t.mutable_data<T>({1}, context.GetPlace());
-      auto* input_max_data =
-          input_max_t.mutable_data<T>({1}, context.GetPlace());
-      auto input_min_scala = framework::EigenScalar<T>::From(input_min_t);
-      auto input_max_scala = framework::EigenScalar<T>::From(input_max_t);
-
-      auto* place =
-          context.template device_context<DeviceContext>().eigen_device();
-      input_min_scala.device(*place) = input_x.minimum();
-      input_max_scala.device(*place) = input_x.maximum();
-
-      Tensor input_min_cpu, input_max_cpu;
-      paddle::framework::TensorCopySync(input_min_t, platform::CPUPlace(),
-                                        &input_min_cpu);
-      paddle::framework::TensorCopySync(input_max_t, platform::CPUPlace(),
-                                        &input_max_cpu);
-
-      output_min = input_min_cpu.data<T>()[0];
-      output_max = input_max_cpu.data<T>()[0];
-    }
-    if (output_min == output_max) {
-      output_min = output_min - 1;
-      output_max = output_max + 1;
-    }
-
-    PADDLE_ENFORCE_EQ(
-        (std::isinf(static_cast<float>(output_min)) ||
-         std::isnan(static_cast<float>(output_max)) ||
-         std::isinf(static_cast<float>(output_min)) ||
-         std::isnan(static_cast<float>(output_max))),
-        false, platform::errors::OutOfRange("range of min, max is not finite"));
-    PADDLE_ENFORCE_GE(
-        output_max, output_min,
-        platform::errors::InvalidArgument(
-            "max must be larger or equal to min. If min and max are both zero, "
-            "the minimum and maximum values of the data are used. "
-            "But received max is %d, min is %d",
-            maxval, minval));
-
-    auto stream =
-        context.template device_context<platform::CUDADeviceContext>().stream();
-    KernelHistogram<
-        T, IndexType><<<GET_BLOCKS(input_numel), PADDLE_CUDA_NUM_THREADS,
-                        nbins * sizeof(int64_t), stream>>>(
-        input_data, input_numel, nbins, output_min, output_max, out_data);
-  }
-};
-
-}  // namespace operators
-}  // namespace paddle
-
-namespace ops = paddle::operators;
-REGISTER_OP_CUDA_KERNEL(
-    histogram,
-    ops::HistogramCUDAKernel<paddle::platform::CUDADeviceContext, float>,
-    ops::HistogramCUDAKernel<paddle::platform::CUDADeviceContext, double>,
-    ops::HistogramCUDAKernel<paddle::platform::CUDADeviceContext, int>,
-    ops::HistogramCUDAKernel<paddle::platform::CUDADeviceContext, int64_t>);
diff --git a/paddle/fluid/operators/histogram_op.h b/paddle/fluid/operators/histogram_op.h
deleted file mode 100644
index 9e280336e492af97d0107062f2d2a5ef22191133..0000000000000000000000000000000000000000
--- a/paddle/fluid/operators/histogram_op.h
+++ /dev/null
@@ -1,84 +0,0 @@
-/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License. */
-
-#pragma once
-
-#include <algorithm>
-
-#include "paddle/fluid/framework/op_registry.h"
-#include "paddle/fluid/framework/operator.h"
-#include "paddle/pten/kernels/funcs/math_function.h"
-
-namespace paddle {
-namespace operators {
-
-using Tensor = framework::Tensor;
-
-template <typename DeviceContext, typename T>
-class HistogramKernel : public framework::OpKernel<T> {
- public:
-  void Compute(const framework::ExecutionContext& context) const override {
-    const Tensor* input = context.Input<Tensor>("X");
-    Tensor* output = context.Output<Tensor>("Out");
-    auto& nbins = context.Attr<int64_t>("bins");
-    auto& minval = context.Attr<int>("min");
-    auto& maxval = context.Attr<int>("max");
-
-    const T* input_data = input->data<T>();
-    auto input_numel = input->numel();
-
-    int64_t* out_data = output->mutable_data<int64_t>(context.GetPlace());
-    pten::funcs::SetConstant<DeviceContext, int64_t>()(
-        context.template device_context<DeviceContext>(), output,
-        static_cast<int64_t>(0));
-
-    if (input_data == nullptr) return;
-
-    T output_min = static_cast<T>(minval);
-    T output_max = static_cast<T>(maxval);
-    if (output_min == output_max) {
-      output_min = *std::min_element(input_data, input_data + input_numel);
-      output_max = *std::max_element(input_data, input_data + input_numel);
-    }
-    if (output_min == output_max) {
-      output_min = output_min - 1;
-      output_max = output_max + 1;
-    }
-
-    PADDLE_ENFORCE_EQ(
-        (std::isinf(static_cast<float>(output_min)) ||
-         std::isnan(static_cast<float>(output_max)) ||
-         std::isinf(static_cast<float>(output_min)) ||
-         std::isnan(static_cast<float>(output_max))),
-        false, platform::errors::OutOfRange("range of min, max is not finite"));
-    PADDLE_ENFORCE_GE(
-        output_max, output_min,
-        platform::errors::InvalidArgument(
-            "max must be larger or equal to min. If min and max are both zero, "
-            "the minimum and maximum values of the data are used. "
-            "But received max is %d, min is %d",
-            maxval, minval));
-
-    for (int64_t i = 0; i < input_numel; i++) {
-      if (input_data[i] >= output_min && input_data[i] <= output_max) {
-        const int64_t bin = (int64_t)((input_data[i] - output_min) * nbins /
-                                      (output_max - output_min));
-        out_data[std::min(bin, nbins - 1)] += 1;
-      }
-    }
-  }
-};
-
-}  // namespace operators
-}  // namespace paddle
diff --git a/paddle/pten/kernels/cpu/histogram_kernel.cc b/paddle/pten/kernels/cpu/histogram_kernel.cc
new file mode 100644
index 0000000000000000000000000000000000000000..700b7e092919aa8d922b0ebfbe8388eb646aac5b
--- /dev/null
+++ b/paddle/pten/kernels/cpu/histogram_kernel.cc
@@ -0,0 +1,88 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/pten/kernels/histogram_kernel.h"
+#include "paddle/pten/backends/cpu/cpu_context.h"
+#include "paddle/pten/core/kernel_registry.h"
+#include "paddle/pten/kernels/funcs/math_function.h"
+
+namespace pten {
+
+template <typename T, typename Context>
+void HistogramKernel(const Context& dev_ctx,
+                     const DenseTensor& input,
+                     int64_t bins,
+                     int min,
+                     int max,
+                     DenseTensor* output) {
+  auto& nbins = bins;
+  auto& minval = min;
+  auto& maxval = max;
+
+  const T* input_data = input.data<T>();
+  auto input_numel = input.numel();
+
+  int64_t* out_data = output->mutable_data<int64_t>(dev_ctx.GetPlace());
+  pten::funcs::SetConstant<Context, int64_t>()(
+      dev_ctx, output, static_cast<int64_t>(0));
+
+  if (input_data == nullptr) return;
+
+  T output_min = static_cast<T>(minval);
+  T output_max = static_cast<T>(maxval);
+  if (output_min == output_max) {
+    output_min = *std::min_element(input_data, input_data + input_numel);
+    output_max = *std::max_element(input_data, input_data + input_numel);
+  }
+  if (output_min == output_max) {
+    output_min = output_min - 1;
+    output_max = output_max + 1;
+  }
+
+  PADDLE_ENFORCE_EQ(
+      (std::isinf(static_cast<float>(output_min)) ||
+       std::isnan(static_cast<float>(output_max)) ||
+       std::isinf(static_cast<float>(output_min)) ||
+       std::isnan(static_cast<float>(output_max))),
+      false,
+      pten::errors::OutOfRange("range of min, max is not finite"));
+  PADDLE_ENFORCE_GE(
+      output_max,
+      output_min,
+      pten::errors::InvalidArgument(
+          "max must be larger or equal to min. If min and max are both zero, "
+          "the minimum and maximum values of the data are used. "
+          "But received max is %d, min is %d",
+          maxval,
+          minval));
+
+  for (int64_t i = 0; i < input_numel; i++) {
+    if (input_data[i] >= output_min && input_data[i] <= output_max) {
+      const int64_t bin = (int64_t)((input_data[i] - output_min) * nbins /
+                                    (output_max - output_min));
+      out_data[std::min(bin, nbins - 1)] += 1;
+    }
+  }
+}
+
+}  // namespace pten
+
+PT_REGISTER_KERNEL(histogram,
+                   CPU,
+                   ALL_LAYOUT,
+                   pten::HistogramKernel,
+                   float,
+                   double,
+                   int,
+                   int64_t) {}
diff --git a/paddle/pten/kernels/gpu/histogram_kernel.cu b/paddle/pten/kernels/gpu/histogram_kernel.cu
new file mode 100644
index 0000000000000000000000000000000000000000..0d0da49e01aebbff375015ddfd7bc90309f9e4d8
--- /dev/null
+++ b/paddle/pten/kernels/gpu/histogram_kernel.cu
@@ -0,0 +1,160 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/pten/kernels/funcs/math_function.h"
+#include "paddle/pten/kernels/histogram_kernel.h"
+
+#include "paddle/pten/backends/gpu/gpu_context.h"
+#include "paddle/pten/core/kernel_registry.h"
+
+#include "paddle/fluid/platform/device/gpu/gpu_launch_config.h"
+#include "paddle/fluid/platform/device/gpu/gpu_primitives.h"
+
+#include "paddle/pten/kernels/funcs/eigen/common.h"
+#include "paddle/pten/kernels/funcs/eigen/eigen_function.h"
+
+namespace pten {
+
+using IndexType = int64_t;
+using paddle::platform::PADDLE_CUDA_NUM_THREADS;
+
+inline int GET_BLOCKS(const int N) {
+  return (N + PADDLE_CUDA_NUM_THREADS - 1) / PADDLE_CUDA_NUM_THREADS;
+}
+
+template <typename T, typename IndexType>
+__device__ static IndexType GetBin(T input_value,
+                                   T min_value,
+                                   T max_value,
+                                   int64_t nbins) {
+  IndexType bin = static_cast<IndexType>((input_value - min_value) * nbins /
+                                         (max_value - min_value));
+  IndexType output_index = bin < nbins - 1 ? bin : nbins - 1;
+  return output_index;
+}
+
+template <typename T, typename IndexType>
+__global__ void KernelHistogram(const T* input,
+                                const int total_elements,
+                                const int64_t nbins,
+                                const T min_value,
+                                const T max_value,
+                                int64_t* output) {
+  extern __shared__ int64_t buf_hist[];
+  for (int i = threadIdx.x; i < nbins; i += blockDim.x) {
+    buf_hist[i] = 0;
+  }
+  __syncthreads();
+
+  CUDA_KERNEL_LOOP(input_index, total_elements) {
+    // const IndexType input_index = threadIdx.x + blockIdx.x * blockDim.x;
+    const auto input_value = input[input_index];
+    if (input_value >= min_value && input_value <= max_value) {
+      const IndexType output_index =
+          GetBin<T, IndexType>(input_value, min_value, max_value, nbins);
+      paddle::platform::CudaAtomicAdd(&buf_hist[output_index], 1);
+    }
+  }
+  __syncthreads();
+
+  for (int i = threadIdx.x; i < nbins; i += blockDim.x) {
+    paddle::platform::CudaAtomicAdd(&output[i], buf_hist[i]);
+  }
+}
+
+template <typename T, typename Context>
+void HistogramKernel(const Context& dev_ctx,
+                     const DenseTensor& input,
+                     int64_t bins,
+                     int min,
+                     int max,
+                     DenseTensor* output) {
+  auto& nbins = bins;
+  auto& minval = min;
+  auto& maxval = max;
+
+  const T* input_data = input.data<T>();
+  const int input_numel = input.numel();
+
+  int64_t* out_data = output->mutable_data<int64_t>(dev_ctx.GetPlace());
+  pten::funcs::SetConstant<Context, int64_t>()(
+      dev_ctx, output, static_cast<int64_t>(0));
+
+  if (input_data == nullptr) return;
+
+  T output_min = static_cast<T>(minval);
+  T output_max = static_cast<T>(maxval);
+
+  if (output_min == output_max) {
+    auto input_x = pten::EigenVector<T>::Flatten(input);
+
+    DenseTensor input_min_t, input_max_t;
+    auto* input_min_data = input_min_t.mutable_data<T>({1}, dev_ctx.GetPlace());
+    auto* input_max_data = input_max_t.mutable_data<T>({1}, dev_ctx.GetPlace());
+    auto input_min_scala = pten::EigenScalar<T>::From(input_min_t);
+    auto input_max_scala = pten::EigenScalar<T>::From(input_max_t);
+
+    auto* place = dev_ctx.eigen_device();
+    input_min_scala.device(*place) = input_x.minimum();
+    input_max_scala.device(*place) = input_x.maximum();
+
+    DenseTensor input_min_cpu, input_max_cpu;
+    paddle::framework::TensorCopySync(
+        input_min_t, paddle::platform::CPUPlace(), &input_min_cpu);
+    paddle::framework::TensorCopySync(
+        input_max_t, paddle::platform::CPUPlace(), &input_max_cpu);
+
+    output_min = input_min_cpu.data<T>()[0];
+    output_max = input_max_cpu.data<T>()[0];
+  }
+  if (output_min == output_max) {
+    output_min = output_min - 1;
+    output_max = output_max + 1;
+  }
+
+  PADDLE_ENFORCE_EQ(
+      (std::isinf(static_cast<float>(output_min)) ||
+       std::isnan(static_cast<float>(output_max)) ||
+       std::isinf(static_cast<float>(output_min)) ||
+       std::isnan(static_cast<float>(output_max))),
+      false,
+      pten::errors::OutOfRange("range of min, max is not finite"));
+  PADDLE_ENFORCE_GE(
+      output_max,
+      output_min,
+      pten::errors::InvalidArgument(
+          "max must be larger or equal to min. If min and max are both zero, "
+          "the minimum and maximum values of the data are used. "
+          "But received max is %d, min is %d",
+          maxval,
+          minval));
+
+  auto stream = dev_ctx.stream();
+  KernelHistogram<T, IndexType><<<GET_BLOCKS(input_numel),
+                                  PADDLE_CUDA_NUM_THREADS,
+                                  nbins * sizeof(int64_t),
+                                  stream>>>(
+      input_data, input_numel, nbins, output_min, output_max, out_data);
+}
+
+}  // namespace pten
+
+PT_REGISTER_KERNEL(histogram,
+                   GPU,
+                   ALL_LAYOUT,
+                   pten::HistogramKernel,
+                   float,
+                   double,
+                   int,
+                   int64_t) {}
diff --git a/paddle/pten/kernels/histogram_kernel.h b/paddle/pten/kernels/histogram_kernel.h
new file mode 100644
index 0000000000000000000000000000000000000000..4bc4ef6fb9e4657305f4f967371711a0aaabb035
--- /dev/null
+++ b/paddle/pten/kernels/histogram_kernel.h
@@ -0,0 +1,28 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include "paddle/pten/core/dense_tensor.h"
+namespace pten {
+
+template <typename T, typename Context>
+void HistogramKernel(const Context& dev_ctx,
+                     const DenseTensor& input,
+                     int64_t bins,
+                     int min,
+                     int max,
+                     DenseTensor* out);
+
+}  // namespace pten
diff --git a/paddle/pten/ops/compat/histogram_sig.cc b/paddle/pten/ops/compat/histogram_sig.cc
new file mode 100644
index 0000000000000000000000000000000000000000..9849c998d779e46bb955f0bc98686c247fc99b18
--- /dev/null
+++ b/paddle/pten/ops/compat/histogram_sig.cc
@@ -0,0 +1,25 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/pten/core/compat/op_utils.h"
+
+namespace pten {
+
+KernelSignature HistogramOpArgumentMapping(const ArgumentMappingContext& ctx) {
+  return KernelSignature("histogram", {"X"}, {"bins", "min", "max"}, {"Out"});
+}
+
+}  // namespace pten
+
+PT_REGISTER_ARG_MAPPING_FN(histogram, pten::HistogramOpArgumentMapping);
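
Note on the shared algorithm (a reference sketch, not part of the patch): the deleted fluid kernels and the new pten kernels use the same binning rule. A value v in [min, max] goes into bin floor((v - min) * nbins / (max - min)), clamped so that v == max still lands in the last bin; when min == max, the kernels fall back to the data's own min/max, widening the range by 1 on each side if the data is constant. A minimal standalone C++ sketch of that rule, with illustrative names, that compiles on its own:

#include <algorithm>
#include <cstdint>
#include <cstdio>
#include <vector>

// Mirrors GetBin and the CPU loop above: linear binning with the upper
// edge clamped into the last bin.
template <typename T>
int64_t ToBin(T value, T min_value, T max_value, int64_t nbins) {
  int64_t bin = static_cast<int64_t>((value - min_value) * nbins /
                                     (max_value - min_value));
  return std::min(bin, nbins - 1);
}

int main() {
  const std::vector<float> data = {0.f, 1.f, 1.f, 2.f, 3.f, 4.f};
  const int64_t nbins = 4;
  const float lo = 0.f, hi = 4.f;
  std::vector<int64_t> hist(nbins, 0);
  for (float v : data) {
    // Values outside [lo, hi] are simply skipped, as in both kernels.
    if (v >= lo && v <= hi) ++hist[ToBin(v, lo, hi, nbins)];
  }
  for (int64_t i = 0; i < nbins; ++i) {
    std::printf("bin %lld: %lld\n", static_cast<long long>(i),
                static_cast<long long>(hist[i]));
  }
  // Prints: bin 0: 1, bin 1: 2, bin 2: 1, bin 3: 2 (v == 4 clamps into bin 3).
  return 0;
}

The GPU path applies the same rule per thread, but each block first accumulates counts into the shared-memory buf_hist and only then merges them into the global output with atomic adds, so atomic contention stays within a block rather than spanning the whole grid.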