From a07f19eedadc49570744a04bfd8a3492518328e7 Mon Sep 17 00:00:00 2001 From: Zhong Hui Date: Thu, 10 Mar 2022 13:27:28 +0800 Subject: [PATCH] [PHI] Move segment_pool to phi. (#40099) * move segment_pool to phi. * mark summed ids as optional tensor. * fix as reviews. --- paddle/fluid/operators/CMakeLists.txt | 2 +- paddle/fluid/operators/math/CMakeLists.txt | 1 - paddle/fluid/operators/segment_pool_op.cc | 37 +-- paddle/fluid/operators/segment_pool_op.cu | 27 -- paddle/fluid/operators/segment_pool_op.h | 176 ----------- paddle/fluid/operators/unity_build_rule.cmake | 4 +- paddle/phi/infermeta/binary.cc | 19 ++ paddle/phi/infermeta/binary.h | 8 + paddle/phi/kernels/CMakeLists.txt | 4 +- .../kernels/cpu/segment_pool_grad_kernel.cc | 26 ++ paddle/phi/kernels/cpu/segment_pool_kernel.cc | 22 ++ paddle/phi/kernels/funcs/CMakeLists.txt | 1 + .../kernels/funcs}/segment_pooling.cc | 84 ++--- .../kernels/funcs}/segment_pooling.cu | 289 +++++++++++------- .../kernels/funcs}/segment_pooling.h | 35 ++- .../kernels/gpu/segment_pool_grad_kernel.cu | 27 ++ paddle/phi/kernels/gpu/segment_pool_kernel.cu | 23 ++ .../impl/segment_pool_grad_kernel_impl.h | 51 ++++ .../kernels/impl/segment_pool_kernel_impl.h | 142 +++++++++ paddle/phi/kernels/segment_pool_grad_kernel.h | 31 ++ paddle/phi/kernels/segment_pool_kernel.h | 29 ++ paddle/phi/ops/compat/segment_pool_sig.cc | 33 ++ 22 files changed, 666 insertions(+), 405 deletions(-) delete mode 100644 paddle/fluid/operators/segment_pool_op.cu delete mode 100644 paddle/fluid/operators/segment_pool_op.h create mode 100644 paddle/phi/kernels/cpu/segment_pool_grad_kernel.cc create mode 100644 paddle/phi/kernels/cpu/segment_pool_kernel.cc rename paddle/{fluid/operators/math => phi/kernels/funcs}/segment_pooling.cc (65%) rename paddle/{fluid/operators/math => phi/kernels/funcs}/segment_pooling.cu (54%) rename paddle/{fluid/operators/math => phi/kernels/funcs}/segment_pooling.h (51%) create mode 100644 paddle/phi/kernels/gpu/segment_pool_grad_kernel.cu create mode 100644 paddle/phi/kernels/gpu/segment_pool_kernel.cu create mode 100644 paddle/phi/kernels/impl/segment_pool_grad_kernel_impl.h create mode 100644 paddle/phi/kernels/impl/segment_pool_kernel_impl.h create mode 100644 paddle/phi/kernels/segment_pool_grad_kernel.h create mode 100644 paddle/phi/kernels/segment_pool_kernel.h create mode 100644 paddle/phi/ops/compat/segment_pool_sig.cc diff --git a/paddle/fluid/operators/CMakeLists.txt b/paddle/fluid/operators/CMakeLists.txt index 91a0352e191..e77be832c0c 100644 --- a/paddle/fluid/operators/CMakeLists.txt +++ b/paddle/fluid/operators/CMakeLists.txt @@ -161,7 +161,7 @@ cc_library(common_infer_shape_functions SRCS common_infer_shape_functions.cc DEP set(COMMON_OP_DEPS ${COMMON_OP_DEPS} selected_rows_functor selected_rows_utils lapack_function lod_tensor maxouting unpooling pooling lod_rank_table context_project -sequence_pooling segment_pooling executor device_memory_aligment generator) +sequence_pooling executor device_memory_aligment generator) set(COMMON_OP_DEPS ${COMMON_OP_DEPS} dynload_warpctc) set(COMMON_OP_DEPS ${COMMON_OP_DEPS} sequence_padding sequence_scale cos_sim_functor memory jit_kernel_helper concat_and_split cross_entropy softmax vol2col im2col sampler sample_prob tree2col) set(COMMON_OP_DEPS ${COMMON_OP_DEPS} sequence2batch lstm_compute matrix_bit_code gru_compute activation_functions beam_search fc matrix_inverse matrix_solve) diff --git a/paddle/fluid/operators/math/CMakeLists.txt b/paddle/fluid/operators/math/CMakeLists.txt index d5a86d62b41..31a98d9f630 100644 --- a/paddle/fluid/operators/math/CMakeLists.txt +++ b/paddle/fluid/operators/math/CMakeLists.txt @@ -46,7 +46,6 @@ math_library(vol2col) math_library(prelu) math_library(bert_encoder_functor) math_library(tree2col DEPS math_function) -math_library(segment_pooling) math_library(matrix_solve) cc_test(selected_rows_functor_test SRCS selected_rows_functor_test.cc DEPS selected_rows_functor) diff --git a/paddle/fluid/operators/segment_pool_op.cc b/paddle/fluid/operators/segment_pool_op.cc index 322cd97f01c..9d4c8532a82 100644 --- a/paddle/fluid/operators/segment_pool_op.cc +++ b/paddle/fluid/operators/segment_pool_op.cc @@ -12,9 +12,12 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "paddle/fluid/operators/segment_pool_op.h" #include #include +#include "paddle/fluid/framework/infershape_utils.h" +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/phi/core/infermeta_utils.h" +#include "paddle/phi/infermeta/binary.h" namespace paddle { namespace operators { @@ -23,22 +26,6 @@ class SegmentPoolOp : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; - void InferShape(framework::InferShapeContext* ctx) const override { - OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "SegmentPool"); - OP_INOUT_CHECK(ctx->HasInput("SegmentIds"), "Input", "SegmentIds", - "SegmentPool"); - OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "SegmentPool"); - auto dims = ctx->GetInputDim("X"); - dims[0] = -1; - ctx->SetOutputDim("Out", dims); - - if (ctx->Attrs().Get("pooltype") == "MEAN") { - OP_INOUT_CHECK(ctx->HasOutput("SummedIds"), "Output", "SummedIds", - "SegmentPool"); - ctx->SetOutputDim("SummedIds", {-1, 1}); - } - } - protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { @@ -150,17 +137,11 @@ class SegmentPoolGradOpMaker : public framework::SingleGradOpMaker { } // namespace paddle namespace ops = paddle::operators; +DECLARE_INFER_SHAPE_FUNCTOR(segment_pool, SegmentPoolInferShapeFunctor, + PD_INFER_META(phi::SegmentPoolInferMeta)); + REGISTER_OPERATOR(segment_pool, ops::SegmentPoolOp, ops::SegmentPoolOpMaker, ops::SegmentPoolGradOpMaker, - ops::SegmentPoolGradOpMaker); + ops::SegmentPoolGradOpMaker, + SegmentPoolInferShapeFunctor); REGISTER_OPERATOR(segment_pool_grad, ops::SegmentPoolGradOp); - -REGISTER_OP_CPU_KERNEL( - segment_pool, - ops::SegmentPoolKernel, - ops::SegmentPoolKernel); - -REGISTER_OP_CPU_KERNEL( - segment_pool_grad, - ops::SegmentPoolGradKernel, - ops::SegmentPoolGradKernel); diff --git a/paddle/fluid/operators/segment_pool_op.cu b/paddle/fluid/operators/segment_pool_op.cu deleted file mode 100644 index e147e62a983..00000000000 --- a/paddle/fluid/operators/segment_pool_op.cu +++ /dev/null @@ -1,27 +0,0 @@ -/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#include "paddle/fluid/operators/segment_pool_op.h" -#include "paddle/fluid/platform/device/gpu/gpu_launch_config.h" -#include "paddle/fluid/platform/device/gpu/gpu_primitives.h" - -namespace ops = paddle::operators; -REGISTER_OP_CUDA_KERNEL( - segment_pool, - ops::SegmentPoolKernel, - ops::SegmentPoolKernel); -REGISTER_OP_CUDA_KERNEL( - segment_pool_grad, - ops::SegmentPoolGradKernel, - ops::SegmentPoolGradKernel); diff --git a/paddle/fluid/operators/segment_pool_op.h b/paddle/fluid/operators/segment_pool_op.h deleted file mode 100644 index 2f5ef7f54f9..00000000000 --- a/paddle/fluid/operators/segment_pool_op.h +++ /dev/null @@ -1,176 +0,0 @@ -/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#pragma once -#include -#include "paddle/fluid/framework/eigen.h" -#include "paddle/fluid/framework/op_registry.h" -#include "paddle/fluid/operators/math/segment_pooling.h" -#include "paddle/fluid/platform/macros.h" -#include "paddle/phi/common/place.h" -#include "paddle/phi/kernels/funcs/math_function.h" - -namespace paddle { -namespace operators { - -using Tensor = framework::Tensor; - -template -void SegmentKernelLaunchHelper(const framework::ExecutionContext& context) { - auto* input = context.Input("X"); - auto* segment = context.Input("SegmentIds"); - auto* output = context.Output("Out"); - std::string pooltype = context.Attr("pooltype"); - Tensor* summed_ids = nullptr; - - int64_t num_indices = segment->numel(); - PADDLE_ENFORCE_EQ( - num_indices, input->dims()[0], - platform::errors::InvalidArgument( - "Segment_ids should be the same size as dimension 0 of input X.")); - PADDLE_ENFORCE_EQ(num_indices, segment->dims()[0], - platform::errors::InvalidArgument( - "Segment_ids should be 1-D tensor, or it's other " - "dimension size is 1. Segment_ids's shape is: [%s].", - segment->dims())); - - if (input->numel() == 0 || segment->numel() == 0) { - return; - } - - bool cpu_place = context.GetPlace().GetType() == phi::AllocationType::CPU; - if (cpu_place) { - auto dims = input->dims(); - auto* segment_ids = segment->data(); - dims[0] = static_cast(segment_ids[segment->numel() - 1] + 1); - PADDLE_ENFORCE_GT( - dims[0], 0, - platform::errors::InvalidArgument( - "Segment ids must be >= 0, but got last id %d", dims[0])); - output->Resize({dims}); - output->mutable_data(context.GetPlace()); - phi::funcs::SetConstant set_zero; - auto& dev_ctx = context.template device_context(); - set_zero(dev_ctx, output, static_cast(0)); - } -#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) - if (!cpu_place) { - Tensor length; - length.mutable_data(phi::make_ddim({1}), platform::CPUPlace()); - IndexT* length_data = length.data(); - const IndexT* segment_ids = segment->data(); - -#ifdef PADDLE_WITH_HIP - PADDLE_ENFORCE_GPU_SUCCESS( - hipMemcpy(length_data, segment_ids + num_indices - 1, sizeof(IndexT), - hipMemcpyDeviceToHost)); -#else - PADDLE_ENFORCE_GPU_SUCCESS( - cudaMemcpy(length_data, segment_ids + num_indices - 1, sizeof(IndexT), - cudaMemcpyDeviceToHost)); -#endif - - IndexT length_host = length_data[0]; - length_host++; - PADDLE_ENFORCE_GT( - length_host, 0, - platform::errors::InvalidArgument( - "Segment ids must be >= 0, but got last id %d", length_data[0])); - auto dims = input->dims(); - dims[0] = static_cast(length_host); - output->Resize({dims}); - output->mutable_data(context.GetPlace()); - T init_value = 0; - if (pooltype == "MAX") { - init_value = static_cast(-FLT_MAX); - } else if (pooltype == "MIN") { - init_value = static_cast(FLT_MAX); - } - phi::funcs::SetConstant setconst; - auto& dev_ctx = context.template device_context(); - setconst(dev_ctx, output, static_cast(init_value)); - // the gpu kernel of mean pool record the counts of segment_ids - if (pooltype == "MEAN") { - summed_ids = context.Output("SummedIds"); - summed_ids->Resize({dims[0], 1}); - summed_ids->mutable_data(context.GetPlace()); - setconst(dev_ctx, summed_ids, static_cast(1e-12)); - } - } -#endif - - SegmentPoolFunctor pool; - - pool(context.template device_context(), *input, *segment, - output, summed_ids, pooltype); -} - -template -class SegmentPoolKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& context) const override { - auto* segment = context.Input("SegmentIds"); - auto index_type = framework::TransToProtoVarType(segment->dtype()); - if (index_type == framework::proto::VarType::INT32) { - SegmentKernelLaunchHelper(context); - } else if (index_type == framework::proto::VarType::INT64) { - SegmentKernelLaunchHelper(context); - } else { - PADDLE_THROW(platform::errors::InvalidArgument( - "Unsupported index type, Expected int, int64, but got %s.", - index_type)); - } - } -}; - -template -class SegmentPoolGradKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& context) const override { - auto* input = context.Input("X"); - auto* output = context.Input("Out"); - auto* segment = context.Input("SegmentIds"); - auto* out_g = context.Input(framework::GradVarName("Out")); - auto* in_g = context.Output(framework::GradVarName("X")); - std::string pooltype = context.Attr("pooltype"); - - const Tensor* summed_ids = nullptr; - if (pooltype == "MEAN") { - summed_ids = context.Input("SummedIds"); - } - - in_g->mutable_data(context.GetPlace()); - phi::funcs::SetConstant set_zero; - auto& dev_ctx = context.template device_context(); - set_zero(dev_ctx, in_g, static_cast(0)); - - auto index_type = framework::TransToProtoVarType(segment->dtype()); - if (index_type == framework::proto::VarType::INT32) { - SegmentPoolGradFunctor pool; - pool(context.template device_context(), *input, *output, - *out_g, *segment, in_g, summed_ids, pooltype); - } else if (index_type == framework::proto::VarType::INT64) { - SegmentPoolGradFunctor pool; - pool(context.template device_context(), *input, *output, - *out_g, *segment, in_g, summed_ids, pooltype); - } else { - PADDLE_THROW(platform::errors::InvalidArgument( - "Unsupported index type, Expected int, int64, but got %s.", - index_type)); - } - } -}; - -} // namespace operators -} // namespace paddle diff --git a/paddle/fluid/operators/unity_build_rule.cmake b/paddle/fluid/operators/unity_build_rule.cmake index 5ab20046178..1be8f3387db 100644 --- a/paddle/fluid/operators/unity_build_rule.cmake +++ b/paddle/fluid/operators/unity_build_rule.cmake @@ -236,7 +236,6 @@ register_unity_group(cc scatter_nd_add_op.cc scatter_op.cc seed_op.cc - segment_pool_op.cc select_input_op.cc select_output_op.cc) register_unity_group(cc @@ -496,8 +495,7 @@ register_unity_group(cu scale_op.cu scatter_nd_add_op.cu scatter_op.cu - seed_op.cu - segment_pool_op.cu) + seed_op.cu) register_unity_group(cu roi_pool_op.cu selu_op.cu diff --git a/paddle/phi/infermeta/binary.cc b/paddle/phi/infermeta/binary.cc index b17405990fb..ff73829c475 100644 --- a/paddle/phi/infermeta/binary.cc +++ b/paddle/phi/infermeta/binary.cc @@ -417,6 +417,25 @@ void Atan2InferMeta(const MetaTensor& x, const MetaTensor& y, MetaTensor* out) { out->share_meta(x); } +void SegmentPoolInferMeta(const MetaTensor& x, + const MetaTensor& segment_ids, + const std::string& pooltype, + MetaTensor* out, + MetaTensor* summed_ids, + MetaConfig config) { + auto dims = x.dims(); + dims[0] = -1; + out->set_dims(dims); + out->set_dtype(x.dtype()); + out->set_layout(x.layout()); + + if (pooltype == "MEAN") { + summed_ids->set_dims({-1, 1}); + summed_ids->set_dtype(x.dtype()); + summed_ids->set_layout(x.layout()); + } +} + void BCELossInferMeta(const MetaTensor& input, const MetaTensor& label, MetaTensor* out, diff --git a/paddle/phi/infermeta/binary.h b/paddle/phi/infermeta/binary.h index 934ed688bf2..bc5cb887f2a 100644 --- a/paddle/phi/infermeta/binary.h +++ b/paddle/phi/infermeta/binary.h @@ -80,6 +80,14 @@ void CrossInferMeta(const MetaTensor& x, MetaTensor* out); void Atan2InferMeta(const MetaTensor& x, const MetaTensor& y, MetaTensor* out); + +void SegmentPoolInferMeta(const MetaTensor& x, + const MetaTensor& segment_ids, + const std::string& pooltype, + MetaTensor* out, + MetaTensor* summed_ids, + MetaConfig config = MetaConfig()); + void BCELossInferMeta(const MetaTensor& input, const MetaTensor& label, MetaTensor* out, diff --git a/paddle/phi/kernels/CMakeLists.txt b/paddle/phi/kernels/CMakeLists.txt index 71e0d9e3479..9b4b14bf51e 100644 --- a/paddle/phi/kernels/CMakeLists.txt +++ b/paddle/phi/kernels/CMakeLists.txt @@ -27,7 +27,7 @@ kernel_library(full_kernel DEPS ${COMMON_KERNEL_DEPS} empty_kernel) # Some kernels depend on some targets that are not commonly used. # These targets are not suitable for common dependencies. # In this case, you need to manually generate them here. -set(MANUAL_BUILD_KERNELS math_kernel softmax_kernel softmax_grad_kernel triangular_solve_grad_kernel maxout_kernel maxout_grad_kernel put_along_axis_kernel put_along_axis_grad_kernel take_along_axis_kernel take_along_axis_grad_kernel eigh_kernel) +set(MANUAL_BUILD_KERNELS math_kernel softmax_kernel softmax_grad_kernel triangular_solve_grad_kernel maxout_kernel maxout_grad_kernel put_along_axis_kernel put_along_axis_grad_kernel take_along_axis_kernel take_along_axis_grad_kernel eigh_kernel segment_pool_kernel segment_pool_grad_kernel) kernel_library(math_kernel DEPS ${COMMON_KERNEL_DEPS} cast_kernel copy_kernel) kernel_library(softmax_kernel DEPS ${COMMON_KERNEL_DEPS} softmax) kernel_library(softmax_grad_kernel DEPS ${COMMON_KERNEL_DEPS} softmax) @@ -39,6 +39,8 @@ kernel_library(put_along_axis_grad_kernel DEPS ${COMMON_KERNEL_DEPS} gather_scat kernel_library(take_along_axis_kernel DEPS ${COMMON_KERNEL_DEPS} gather_scatter_kernel) kernel_library(take_along_axis_grad_kernel DEPS ${COMMON_KERNEL_DEPS} gather_scatter_kernel) kernel_library(eigh_kernel DEPS ${COMMON_KERNEL_DEPS} lapack_function) +kernel_library(segment_pool_kernel DEPS ${COMMON_KERNEL_DEPS} segment_pooling) +kernel_library(segment_pool_grad_kernel DEPS ${COMMON_KERNEL_DEPS} segment_pooling) # 4. auto parse and build kernel targets by cmake register_kernels(EXCLUDES ${COMMON_BAISC_KERNELS} ${MANUAL_BUILD_KERNELS} DEPS ${COMMON_KERNEL_DEPS} ${COMMON_BAISC_KERNELS} ) diff --git a/paddle/phi/kernels/cpu/segment_pool_grad_kernel.cc b/paddle/phi/kernels/cpu/segment_pool_grad_kernel.cc new file mode 100644 index 00000000000..585c27bdcec --- /dev/null +++ b/paddle/phi/kernels/cpu/segment_pool_grad_kernel.cc @@ -0,0 +1,26 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/segment_pool_grad_kernel.h" +#include "paddle/phi/kernels/impl/segment_pool_grad_kernel_impl.h" + +#include "paddle/phi/backends/cpu/cpu_context.h" +#include "paddle/phi/core/kernel_registry.h" + +PD_REGISTER_KERNEL(segment_pool_grad, + CPU, + ALL_LAYOUT, + phi::SegmentPoolGradKernel, + float, + double) {} diff --git a/paddle/phi/kernels/cpu/segment_pool_kernel.cc b/paddle/phi/kernels/cpu/segment_pool_kernel.cc new file mode 100644 index 00000000000..d0413457f81 --- /dev/null +++ b/paddle/phi/kernels/cpu/segment_pool_kernel.cc @@ -0,0 +1,22 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/segment_pool_kernel.h" +#include "paddle/phi/kernels/impl/segment_pool_kernel_impl.h" + +#include "paddle/phi/backends/cpu/cpu_context.h" +#include "paddle/phi/core/kernel_registry.h" + +PD_REGISTER_KERNEL( + segment_pool, CPU, ALL_LAYOUT, phi::SegmentPoolKernel, float, double) {} diff --git a/paddle/phi/kernels/funcs/CMakeLists.txt b/paddle/phi/kernels/funcs/CMakeLists.txt index f0fbb7bf084..e0db7b51f8e 100644 --- a/paddle/phi/kernels/funcs/CMakeLists.txt +++ b/paddle/phi/kernels/funcs/CMakeLists.txt @@ -4,6 +4,7 @@ add_subdirectory(lapack) add_subdirectory(detail) math_library(math_function DEPS blas dense_tensor tensor) +math_library(segment_pooling) math_library(sequence2batch) math_library(gru_compute DEPS activation_functions math_function) math_library(lstm_compute DEPS activation_functions) diff --git a/paddle/fluid/operators/math/segment_pooling.cc b/paddle/phi/kernels/funcs/segment_pooling.cc similarity index 65% rename from paddle/fluid/operators/math/segment_pooling.cc rename to paddle/phi/kernels/funcs/segment_pooling.cc index d16fc570a9f..bf4a21f3722 100644 --- a/paddle/fluid/operators/math/segment_pooling.cc +++ b/paddle/phi/kernels/funcs/segment_pooling.cc @@ -12,45 +12,52 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "paddle/fluid/operators/math/segment_pooling.h" +#include "paddle/phi/kernels/funcs/segment_pooling.h" #include -#include "paddle/fluid/framework/eigen.h" -namespace paddle { -namespace operators { +#include "paddle/phi/backends/cpu/cpu_context.h" +#include "paddle/phi/kernels/funcs/eigen/common.h" -using Tensor = framework::Tensor; +namespace phi { +namespace funcs { + +using Tensor = DenseTensor; template -class SegmentPoolFunctor { +class SegmentPoolFunctor { public: - void operator()(const platform::CPUDeviceContext& context, - const framework::Tensor& input, - const framework::Tensor& segments, framework::Tensor* output, - framework::Tensor* index, + void operator()(const phi::CPUContext& dev_ctx, + const DenseTensor& input, + const DenseTensor& segments, + DenseTensor* output, + DenseTensor* index, const std::string pooltype = "SUM") { const IndexT* segment_ids = segments.data(); auto curent_id = segment_ids[0]; int64_t last_idx = 0; int64_t w = input.numel() / input.dims()[0]; - auto& place = *context.eigen_device(); + auto& place = *dev_ctx.eigen_device(); for (int64_t idx = 1; idx <= segments.numel(); ++idx) { if (idx < segments.numel()) { if (segment_ids[idx] == curent_id) continue; - PADDLE_ENFORCE_GE(segment_ids[idx], curent_id, - platform::errors::InvalidArgument( + PADDLE_ENFORCE_GE(segment_ids[idx], + curent_id, + phi::errors::InvalidArgument( "The segment ids should be sorted, but got " "segment_ids[%d]:%d > segment_ids[%d]:%d.", - idx - 1, curent_id, idx, segment_ids[idx])); + idx - 1, + curent_id, + idx, + segment_ids[idx])); } Tensor out_t = output->Slice(curent_id, curent_id + 1); Tensor in_t = input.Slice(last_idx, idx); int64_t h = idx - last_idx; - auto in_e = framework::EigenMatrix::From(in_t, phi::make_ddim({h, w})); - auto out_e = framework::EigenVector::Flatten(out_t); + auto in_e = EigenMatrix::From(in_t, phi::make_ddim({h, w})); + auto out_e = EigenVector::Flatten(out_t); auto reduce_dim = Eigen::array({{0}}); if (pooltype == "MEAN") { @@ -62,7 +69,7 @@ class SegmentPoolFunctor { } else if (pooltype == "MIN") { out_e.device(place) = in_e.minimum(reduce_dim); } else { - PADDLE_THROW(platform::errors::InvalidArgument( + PADDLE_THROW(phi::errors::InvalidArgument( "Unsupported segment pooling type, only MEAN, SUM, MAX, MIN " "available, but got %s.", pooltype)); @@ -75,36 +82,41 @@ class SegmentPoolFunctor { }; template -class SegmentPoolGradFunctor { +class SegmentPoolGradFunctor { public: - void operator()(const platform::CPUDeviceContext& context, - const framework::Tensor& input, - const framework::Tensor& output, - const framework::Tensor& out_grad, - const framework::Tensor& segments, framework::Tensor* in_grad, - const framework::Tensor* index = nullptr, + void operator()(const phi::CPUContext& dev_ctx, + const DenseTensor& input, + const DenseTensor& output, + const DenseTensor& out_grad, + const DenseTensor& segments, + DenseTensor* in_grad, + paddle::optional index, const std::string pooltype = "SUM") { const IndexT* segment_ids = segments.data(); - auto& place = *context.eigen_device(); + auto& place = *dev_ctx.eigen_device(); auto curent_id = segment_ids[0]; int64_t last_idx = 0; int64_t w = in_grad->numel() / in_grad->dims()[0]; for (int64_t idx = 1; idx <= segments.numel(); ++idx) { if (idx < segments.numel()) { if (segment_ids[idx] == curent_id) continue; - PADDLE_ENFORCE_GE(segment_ids[idx], curent_id, - platform::errors::InvalidArgument( + PADDLE_ENFORCE_GE(segment_ids[idx], + curent_id, + phi::errors::InvalidArgument( "The segment ids should be sorted, but got " "segment_ids[%d]:%d > segment_ids[%d]:%d.", - idx - 1, curent_id, idx, segment_ids[idx])); + idx - 1, + curent_id, + idx, + segment_ids[idx])); } Tensor out_g_t = out_grad.Slice(curent_id, curent_id + 1); Tensor in_g_t = in_grad->Slice(last_idx, idx); int64_t h = idx - last_idx; - auto in_g_e = framework::EigenMatrix::From(in_g_t, {h, w}); - auto out_g_e = framework::EigenMatrix::From(out_g_t, {1, w}); + auto in_g_e = EigenMatrix::From(in_g_t, {h, w}); + auto out_g_e = EigenMatrix::From(out_g_t, {1, w}); Eigen::DSizes bcast(h, 1); if (pooltype == "MEAN") { @@ -114,13 +126,13 @@ class SegmentPoolGradFunctor { } else if (pooltype == "MAX" || pooltype == "MIN") { Tensor out_t = output.Slice(curent_id, curent_id + 1); Tensor in_t = input.Slice(last_idx, idx); - auto in_e = framework::EigenMatrix::From(in_t, {h, w}); - auto out_e = framework::EigenMatrix::From(out_t, {1, w}); + auto in_e = EigenMatrix::From(in_t, {h, w}); + auto out_e = EigenMatrix::From(out_t, {1, w}); in_g_e.device(place) = (in_e == out_e.broadcast(bcast)).template cast() * out_g_e.broadcast(bcast); } else { - PADDLE_THROW(platform::errors::InvalidArgument( + PADDLE_THROW(phi::errors::InvalidArgument( "Unsupported segment pooling type, only MEAN, SUM, MAX, MIN " "available, but got %s.", pooltype)); @@ -132,7 +144,7 @@ class SegmentPoolGradFunctor { } }; -using CPU = platform::CPUDeviceContext; +using CPU = phi::CPUContext; template class SegmentPoolFunctor; template class SegmentPoolFunctor; template class SegmentPoolFunctor; @@ -142,5 +154,5 @@ template class SegmentPoolGradFunctor; template class SegmentPoolGradFunctor; template class SegmentPoolGradFunctor; -} // namespace operators -} // namespace paddle +} // namespace funcs +} // namespace phi diff --git a/paddle/fluid/operators/math/segment_pooling.cu b/paddle/phi/kernels/funcs/segment_pooling.cu similarity index 54% rename from paddle/fluid/operators/math/segment_pooling.cu rename to paddle/phi/kernels/funcs/segment_pooling.cu index fbdcb99c02a..305cd39f077 100644 --- a/paddle/fluid/operators/math/segment_pooling.cu +++ b/paddle/phi/kernels/funcs/segment_pooling.cu @@ -12,20 +12,24 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ +#include "paddle/phi/kernels/funcs/segment_pooling.h" + #include -#include "paddle/fluid/operators/math/segment_pooling.h" -#include "paddle/fluid/platform/device/gpu/gpu_launch_config.h" + #include "paddle/fluid/platform/device/gpu/gpu_primitives.h" +#include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/backends/gpu/gpu_launch_config.h" #include "paddle/phi/kernels/funcs/gather.cu.h" #include "paddle/phi/kernels/funcs/math_function.h" -namespace paddle { -namespace operators { +namespace phi { +namespace funcs { -using Tensor = framework::Tensor; +using Tensor = DenseTensor; template -__global__ void SegmentSumIdsKernel(const Index* segment_ids, T* summed_ids, +__global__ void SegmentSumIdsKernel(const Index* segment_ids, + T* summed_ids, const Index input_length_size, const Index total_stripe_count) { CUDA_KERNEL_LOOP(stripe_index, total_stripe_count) { @@ -45,16 +49,19 @@ __global__ void SegmentSumIdsKernel(const Index* segment_ids, T* summed_ids, PADDLE_ENFORCE(current_segment_id >= last_segment_id, "the segment ids should be sorted, but got " "segment_ids[%d]:%d > segment_ids[%d]:%d.", - dim_index_base + j - 1, dim_index_base + j, - last_segment_id, current_segment_id); + dim_index_base + j - 1, + dim_index_base + j, + last_segment_id, + current_segment_id); if (current_segment_id > last_segment_id) { for (Index interval_id = last_segment_id + 1; - interval_id < current_segment_id; ++interval_id) { + interval_id < current_segment_id; + ++interval_id) { *(summed_ids + interval_id) = 0; } if (j > 0) { if (last_segment_id == first_segment_id) { - platform::CudaAtomicAdd(summed_ids + last_segment_id, sum); + paddle::platform::CudaAtomicAdd(summed_ids + last_segment_id, sum); } else { *(summed_ids + last_segment_id) = sum; } @@ -64,13 +71,15 @@ __global__ void SegmentSumIdsKernel(const Index* segment_ids, T* summed_ids, sum += T(1); last_segment_id = current_segment_id; } - platform::CudaAtomicAdd(summed_ids + last_segment_id, sum); + paddle::platform::CudaAtomicAdd(summed_ids + last_segment_id, sum); } } template -__global__ void SegmentMeanKernel(const Index* segment_ids, const T* input, - T* output, T* summed_ids, +__global__ void SegmentMeanKernel(const Index* segment_ids, + const T* input, + T* output, + T* summed_ids, const Index input_length_size, const Index inner_dim_size, const Index output_length_size, @@ -93,7 +102,8 @@ __global__ void SegmentMeanKernel(const Index* segment_ids, const T* input, if (current_segment_id > last_segment_id) { // reset the interval value which do not have corresponding ids. for (Index interval_id = last_segment_id + 1; - interval_id < current_segment_id; ++interval_id) { + interval_id < current_segment_id; + ++interval_id) { *(output + interval_id * inner_dim_size + segment_offset) = T(0); } @@ -102,8 +112,8 @@ __global__ void SegmentMeanKernel(const Index* segment_ids, const T* input, last_segment_id * inner_dim_size + segment_offset; if (last_segment_id == first_segment_id) { - platform::CudaAtomicAdd(output + output_index, - sum / *(summed_ids + last_segment_id)); + paddle::platform::CudaAtomicAdd( + output + output_index, sum / *(summed_ids + last_segment_id)); } else { *(output + output_index) = sum / *(summed_ids + last_segment_id); } @@ -114,15 +124,14 @@ __global__ void SegmentMeanKernel(const Index* segment_ids, const T* input, last_segment_id = current_segment_id; } Index output_index = last_segment_id * inner_dim_size + segment_offset; - platform::CudaAtomicAdd(output + output_index, - sum / *(summed_ids + last_segment_id)); + paddle::platform::CudaAtomicAdd(output + output_index, + sum / *(summed_ids + last_segment_id)); } } template -__global__ void __launch_bounds__(1024, 1) - SegmentOpsKernel(const Index* segment_ids, const T* input, T* output, - Helper h, Pool pool) { +__global__ void __launch_bounds__(1024, 1) SegmentOpsKernel( + const Index* segment_ids, const T* input, T* output, Helper h, Pool pool) { CUDA_KERNEL_LOOP(stripe_index, h.total_stripe_count) { Index segment_offset, dim_index_base, actual_height; Index inner_dim_size = h.inner_dim_size; @@ -142,13 +151,16 @@ __global__ void __launch_bounds__(1024, 1) PADDLE_ENFORCE(current_segment_id >= last_segment_id, "The segment ids should be sorted, but got " "segment_ids[%d]:%d > segment_ids[%d]:%d.", - dim_index_base + j - 1, dim_index_base + j, - last_segment_id, current_segment_id); + dim_index_base + j - 1, + dim_index_base + j, + last_segment_id, + current_segment_id); if (current_segment_id > last_segment_id) { // reset the interval value which do not have corresponding ids. for (Index interval_id = last_segment_id + 1; - interval_id < current_segment_id; ++interval_id) { + interval_id < current_segment_id; + ++interval_id) { *(output + interval_id * inner_dim_size + segment_offset) = T(0); } // don't update result when j=0 @@ -175,9 +187,12 @@ __global__ void __launch_bounds__(1024, 1) } template -__global__ void SegmentIndexGradKernel(const Index* segment_ids, const T* input, - const T* output, const T* out_grad, - T* in_grad, Helper h) { +__global__ void SegmentIndexGradKernel(const Index* segment_ids, + const T* input, + const T* output, + const T* out_grad, + T* in_grad, + Helper h) { CUDA_KERNEL_LOOP(stripe_index, h.total_stripe_count) { Index segment_offset, dim_index_base, actual_height; h.calculate(stripe_index, &segment_offset, &dim_index_base, &actual_height); @@ -201,7 +216,7 @@ class MaxPool { DEVICE inline T initial() { return static_cast(-FLT_MAX); } DEVICE inline void compute(const T& x, T* y) { *y = *y > x ? *y : x; } DEVICE inline T atomic(T* address, const T val) { - return platform::CudaAtomicMax(address, val); + return paddle::platform::CudaAtomicMax(address, val); } }; @@ -211,7 +226,7 @@ class MinPool { DEVICE inline T initial() { return static_cast(FLT_MAX); } DEVICE inline void compute(const T& x, T* y) { *y = *y < x ? *y : x; } DEVICE inline T atomic(T* address, const T val) { - return platform::CudaAtomicMin(address, val); + return paddle::platform::CudaAtomicMin(address, val); } }; @@ -221,7 +236,7 @@ class SumPool { DEVICE inline T initial() { return static_cast(0); } DEVICE inline void compute(const T& x, T* y) { *y = *y + x; } DEVICE inline T atomic(T* address, const T val) { - return platform::CudaAtomicAdd(address, val); + return paddle::platform::CudaAtomicAdd(address, val); } }; @@ -243,8 +258,10 @@ class ArrangeHelper { total_stripe_count = inner_dim_size * input_outer_dim_num_stripe; } - DEVICE inline void calculate(T stripe_index, T* segment_offset, - T* dim_index_base, T* actual_height) { + DEVICE inline void calculate(T stripe_index, + T* segment_offset, + T* dim_index_base, + T* actual_height) { *segment_offset = stripe_index % inner_dim_size; *dim_index_base = stripe_index / inner_dim_size * DimTileSize; *actual_height = min(DimTileSize, input_length_size - *dim_index_base); @@ -252,23 +269,32 @@ class ArrangeHelper { }; template -void SegmentPoolCUDAGradFunctor(const platform::CUDADeviceContext& ctx, - const framework::Tensor& input, - const framework::Tensor& segment_ids, - const framework::Tensor& output, - const framework::Tensor& out_grad, - framework::Tensor* in_grad, +void SegmentPoolCUDAGradFunctor(const phi::GPUContext& ctx, + const DenseTensor& input, + const DenseTensor& segment_ids, + const DenseTensor& output, + const DenseTensor& out_grad, + DenseTensor* in_grad, const std::string pooltype = "SUM") { - auto h = ArrangeHelper(input.numel(), segment_ids.dims()[0], - output.dims()[0]); - auto config = platform::GetGpuLaunchConfig1D(ctx, h.total_stripe_count); + auto h = ArrangeHelper( + input.numel(), segment_ids.dims()[0], output.dims()[0]); + auto config = + phi::backends::gpu::GetGpuLaunchConfig1D(ctx, h.total_stripe_count); if (pooltype == "MAX" || pooltype == "MIN") { - SegmentIndexGradKernel><<< - config.block_per_grid.x, config.thread_per_block.x, 0, ctx.stream()>>>( - segment_ids.data(), input.data(), output.data(), - out_grad.data(), in_grad->data(), h); + SegmentIndexGradKernel><<>>( + segment_ids.data(), + input.data(), + output.data(), + out_grad.data(), + in_grad->data(), + h); } else { - PADDLE_THROW(platform::errors::InvalidArgument( + PADDLE_THROW(phi::errors::InvalidArgument( "Unsupported segment pooling grad operation, Only MAX, MIN " "available, but got %s.", pooltype)); @@ -291,13 +317,13 @@ __global__ void SimpleDiv(T* x, const T* y, const int len, const int dim) { } template -class SegmentPoolFunctor { +class SegmentPoolFunctor { public: - void operator()(const platform::CUDADeviceContext& ctx, - const framework::Tensor& input, - const framework::Tensor& segment_ids, - framework::Tensor* output, - framework::Tensor* summed_ids = nullptr, + void operator()(const phi::GPUContext& ctx, + const DenseTensor& input, + const DenseTensor& segment_ids, + DenseTensor* output, + DenseTensor* summed_ids = nullptr, const std::string pooltype = "SUM") { if (pooltype == "MEAN") { // Sum the segment id num first @@ -305,50 +331,76 @@ class SegmentPoolFunctor { auto input_length_size = segment_ids.numel(); auto total_stripe_count = (input_length_size + DimTileSize - 1) / DimTileSize; - auto config = platform::GetGpuLaunchConfig1D(ctx, total_stripe_count); - SegmentSumIdsKernel< - T, IndexT, IndexT(8)><<>>( - segment_ids.data(), summed_ids->data(), input_length_size, + auto config = + phi::backends::gpu::GetGpuLaunchConfig1D(ctx, total_stripe_count); + SegmentSumIdsKernel<<>>( + segment_ids.data(), + summed_ids->data(), + input_length_size, total_stripe_count); } - auto h = ArrangeHelper(input.numel(), segment_ids.dims()[0], - output->dims()[0]); - auto config = platform::GetGpuLaunchConfig1D(ctx, h.total_stripe_count); + auto h = ArrangeHelper( + input.numel(), segment_ids.dims()[0], output->dims()[0]); + auto config = + phi::backends::gpu::GetGpuLaunchConfig1D(ctx, h.total_stripe_count); if (pooltype == "MEAN") { - SegmentMeanKernel< - T, IndexT, IndexT(8)><<>>( - segment_ids.data(), input.data(), output->data(), - summed_ids->data(), h.input_length_size, h.inner_dim_size, - h.output_length_size, h.total_stripe_count); + SegmentMeanKernel<<>>( + segment_ids.data(), + input.data(), + output->data(), + summed_ids->data(), + h.input_length_size, + h.inner_dim_size, + h.output_length_size, + h.total_stripe_count); } else if (pooltype == "SUM") { SumPool pool; - SegmentOpsKernel< - T, IndexT, ArrangeHelper, - SumPool><<>>(segment_ids.data(), - input.data(), output->data(), h, - pool); + SegmentOpsKernel, + SumPool><<>>(segment_ids.data(), + input.data(), + output->data(), + h, + pool); } else if (pooltype == "MAX") { MaxPool pool; - SegmentOpsKernel< - T, IndexT, ArrangeHelper, - MaxPool><<>>(segment_ids.data(), - input.data(), output->data(), h, - pool); + SegmentOpsKernel, + MaxPool><<>>(segment_ids.data(), + input.data(), + output->data(), + h, + pool); } else if (pooltype == "MIN") { MinPool pool; - SegmentOpsKernel< - T, IndexT, ArrangeHelper, - MinPool><<>>(segment_ids.data(), - input.data(), output->data(), h, - pool); + SegmentOpsKernel, + MinPool><<>>(segment_ids.data(), + input.data(), + output->data(), + h, + pool); } else { - PADDLE_THROW(platform::errors::InvalidArgument( + PADDLE_THROW(phi::errors::InvalidArgument( "Unsupported segment pooling operation, Only MEAN, SUM, MAX, MIN " "available, but got %s.", pooltype)); @@ -357,33 +409,38 @@ class SegmentPoolFunctor { }; template -class SegmentPoolGradFunctor { +class SegmentPoolGradFunctor { public: - void operator()(const platform::CUDADeviceContext& context, - const framework::Tensor& input, - const framework::Tensor& output, - const framework::Tensor& out_grad, - const framework::Tensor& segments, framework::Tensor* in_grad, - const framework::Tensor* summed_ids = nullptr, + void operator()(const phi::GPUContext& dev_ctx, + const DenseTensor& input, + const DenseTensor& output, + const DenseTensor& out_grad, + const DenseTensor& segments, + DenseTensor* in_grad, + paddle::optional summed_ids, const std::string pooltype = "SUM") { if (pooltype == "MAX" || pooltype == "MIN") { - SegmentPoolCUDAGradFunctor(context, input, segments, output, - out_grad, in_grad, pooltype); + SegmentPoolCUDAGradFunctor( + dev_ctx, input, segments, output, out_grad, in_grad, pooltype); } else if (pooltype == "MEAN") { - framework::Tensor mean_grad; - mean_grad.mutable_data(input.dims(), context.GetPlace()); - framework::TensorCopy(out_grad, context.GetPlace(), context, &mean_grad); + DenseTensor mean_grad; + mean_grad.Resize(input.dims()); + dev_ctx.template Alloc(&mean_grad); + paddle::framework::TensorCopy( + out_grad, dev_ctx.GetPlace(), dev_ctx, &mean_grad); int len = output.dims()[0]; int dim = output.numel() / len; - auto config = platform::GetGpuLaunchConfig1D(context, len); - SimpleDiv<<>>(mean_grad.data(), - summed_ids->data(), len, dim); - phi::funcs::GPUGather(context, mean_grad, segments, in_grad); + auto config = phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, len); + SimpleDiv<<>>( + mean_grad.data(), summed_ids->data(), len, dim); + phi::funcs::GPUGather(dev_ctx, mean_grad, segments, in_grad); } else if (pooltype == "SUM") { - phi::funcs::GPUGather(context, out_grad, segments, in_grad); + phi::funcs::GPUGather(dev_ctx, out_grad, segments, in_grad); } else { - PADDLE_THROW(platform::errors::InvalidArgument( + PADDLE_THROW(phi::errors::InvalidArgument( "Unsupported segment pooling operation, Only MEAN, SUM, MAX, MIN " "available, but got %s.", pooltype)); @@ -391,15 +448,15 @@ class SegmentPoolGradFunctor { } }; -using CUDA = paddle::platform::CUDADeviceContext; -template class SegmentPoolFunctor; -template class SegmentPoolFunctor; -template class SegmentPoolFunctor; -template class SegmentPoolFunctor; -template class SegmentPoolGradFunctor; -template class SegmentPoolGradFunctor; -template class SegmentPoolGradFunctor; -template class SegmentPoolGradFunctor; - -} // namespace operators -} // namespace paddle +using GPU = phi::GPUContext; +template class SegmentPoolFunctor; +template class SegmentPoolFunctor; +template class SegmentPoolFunctor; +template class SegmentPoolFunctor; +template class SegmentPoolGradFunctor; +template class SegmentPoolGradFunctor; +template class SegmentPoolGradFunctor; +template class SegmentPoolGradFunctor; + +} // namespace funcs +} // namespace phi diff --git a/paddle/fluid/operators/math/segment_pooling.h b/paddle/phi/kernels/funcs/segment_pooling.h similarity index 51% rename from paddle/fluid/operators/math/segment_pooling.h rename to paddle/phi/kernels/funcs/segment_pooling.h index 561fad6921f..b8281061582 100644 --- a/paddle/fluid/operators/math/segment_pooling.h +++ b/paddle/phi/kernels/funcs/segment_pooling.h @@ -14,33 +14,36 @@ limitations under the License. */ #pragma once #include -#include "paddle/fluid/framework/tensor.h" -#include "paddle/fluid/platform/device_context.h" +#include "paddle/phi/core/dense_tensor.h" -namespace paddle { -namespace operators { +namespace phi { +namespace funcs { -template +template class SegmentPoolFunctor { public: /* mean pool has summed_ids output */ - void operator()(const DeviceContext& context, const framework::Tensor& input, - const framework::Tensor& segments, framework::Tensor* output, - framework::Tensor* summed_ids = nullptr, + void operator()(const Context& dev_ctx, + const DenseTensor& input, + const DenseTensor& segments, + DenseTensor* output, + DenseTensor* summed_ids = nullptr, const std::string pooltype = "SUM"); }; -template +template class SegmentPoolGradFunctor { public: /* mean pool has summed_ids output */ - void operator()(const DeviceContext& context, const framework::Tensor& input, - const framework::Tensor& output, - const framework::Tensor& out_grad, - const framework::Tensor& segments, framework::Tensor* in_grad, - const framework::Tensor* summed_ids = nullptr, + void operator()(const Context& dev_ctx, + const DenseTensor& input, + const DenseTensor& output, + const DenseTensor& out_grad, + const DenseTensor& segments, + DenseTensor* in_grad, + paddle::optional summed_ids, const std::string pooltype = "SUM"); }; -} // namespace operators -} // namespace paddle +} // namespace funcs +} // namespace phi diff --git a/paddle/phi/kernels/gpu/segment_pool_grad_kernel.cu b/paddle/phi/kernels/gpu/segment_pool_grad_kernel.cu new file mode 100644 index 00000000000..d9618dc159a --- /dev/null +++ b/paddle/phi/kernels/gpu/segment_pool_grad_kernel.cu @@ -0,0 +1,27 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/impl/segment_pool_grad_kernel_impl.h" +#include "paddle/phi/kernels/segment_pool_grad_kernel.h" + +#include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/core/dense_tensor.h" +#include "paddle/phi/core/kernel_registry.h" + +PD_REGISTER_KERNEL(segment_pool_grad, + GPU, + ALL_LAYOUT, + phi::SegmentPoolGradKernel, + float, + double) {} diff --git a/paddle/phi/kernels/gpu/segment_pool_kernel.cu b/paddle/phi/kernels/gpu/segment_pool_kernel.cu new file mode 100644 index 00000000000..c38e935adf8 --- /dev/null +++ b/paddle/phi/kernels/gpu/segment_pool_kernel.cu @@ -0,0 +1,23 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/impl/segment_pool_kernel_impl.h" +#include "paddle/phi/kernels/segment_pool_kernel.h" + +#include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/core/dense_tensor.h" +#include "paddle/phi/core/kernel_registry.h" + +PD_REGISTER_KERNEL( + segment_pool, GPU, ALL_LAYOUT, phi::SegmentPoolKernel, float, double) {} diff --git a/paddle/phi/kernels/impl/segment_pool_grad_kernel_impl.h b/paddle/phi/kernels/impl/segment_pool_grad_kernel_impl.h new file mode 100644 index 00000000000..4ba1a0c6b6c --- /dev/null +++ b/paddle/phi/kernels/impl/segment_pool_grad_kernel_impl.h @@ -0,0 +1,51 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include "paddle/fluid/framework/eigen.h" +#include "paddle/phi/common/place.h" +#include "paddle/phi/kernels/funcs/math_function.h" +#include "paddle/phi/kernels/funcs/segment_pooling.h" + +namespace phi { + +template +void SegmentPoolGradKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& segment_ids, + const DenseTensor& out, + paddle::optional summed_ids, + const DenseTensor& out_grad, + const std::string& pooltype, + DenseTensor* x_grad) { + dev_ctx.template Alloc(x_grad); + phi::funcs::SetConstant set_zero; + set_zero(dev_ctx, x_grad, static_cast(0)); + + auto index_type = segment_ids.type(); + if (index_type == DataType::INT32) { + phi::funcs::SegmentPoolGradFunctor pool; + pool(dev_ctx, x, out, out_grad, segment_ids, x_grad, summed_ids, pooltype); + } else if (index_type == DataType::INT64) { + phi::funcs::SegmentPoolGradFunctor pool; + pool(dev_ctx, x, out, out_grad, segment_ids, x_grad, summed_ids, pooltype); + } else { + PADDLE_THROW(phi::errors::InvalidArgument( + "Unsupported index type, Expected int, int64, but got %s.", + index_type)); + } +} +} // namespace phi diff --git a/paddle/phi/kernels/impl/segment_pool_kernel_impl.h b/paddle/phi/kernels/impl/segment_pool_kernel_impl.h new file mode 100644 index 00000000000..8a6df37ab3e --- /dev/null +++ b/paddle/phi/kernels/impl/segment_pool_kernel_impl.h @@ -0,0 +1,142 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include "paddle/fluid/framework/eigen.h" +#include "paddle/phi/common/place.h" +#include "paddle/phi/kernels/funcs/math_function.h" +#include "paddle/phi/kernels/funcs/segment_pooling.h" + +namespace phi { + +template +void SegmentKernelLaunchHelper(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& segment_ids, + const std::string& pooltype, + DenseTensor* out, + DenseTensor* summed_ids) { + int64_t num_indices = segment_ids.numel(); + PADDLE_ENFORCE_EQ( + num_indices, + x.dims()[0], + phi::errors::InvalidArgument( + "Segment_ids should be the same size as dimension 0 of input X.")); + PADDLE_ENFORCE_EQ(num_indices, + segment_ids.dims()[0], + phi::errors::InvalidArgument( + "Segment_ids should be 1-D tensor, or it's other " + "dimension size is 1. Segment_ids's shape is: [%s].", + segment_ids.dims())); + + if (x.numel() == 0 || segment_ids.numel() == 0) { + return; + } + + bool cpu_place = dev_ctx.GetPlace().GetType() == phi::AllocationType::CPU; + if (cpu_place) { + auto dims = x.dims(); + auto* segment_ids_ptr = segment_ids.data(); + dims[0] = + static_cast(segment_ids_ptr[segment_ids.numel() - 1] + 1); + PADDLE_ENFORCE_GT( + dims[0], + 0, + phi::errors::InvalidArgument( + "Segment ids must be >= 0, but got last id %d", dims[0])); + + out->Resize({dims}); + dev_ctx.template Alloc(out); + + phi::funcs::SetConstant set_zero; + set_zero(dev_ctx, out, static_cast(0)); + } +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) + if (!cpu_place) { + DenseTensor length; + length.Resize(phi::make_ddim({1})); + IndexT* length_data = dev_ctx.template HostAlloc(&length); + + const IndexT* segment_ids_ptr = segment_ids.data(); + +#ifdef PADDLE_WITH_HIP + PADDLE_ENFORCE_GPU_SUCCESS(hipMemcpy(length_data, + segment_ids_ptr + num_indices - 1, + sizeof(IndexT), + hipMemcpyDeviceToHost)); +#else + PADDLE_ENFORCE_GPU_SUCCESS(cudaMemcpy(length_data, + segment_ids_ptr + num_indices - 1, + sizeof(IndexT), + cudaMemcpyDeviceToHost)); +#endif + + IndexT length_host = length_data[0]; + length_host++; + PADDLE_ENFORCE_GT( + length_host, + 0, + phi::errors::InvalidArgument( + "Segment ids must be >= 0, but got last id %d", length_data[0])); + auto dims = x.dims(); + dims[0] = static_cast(length_host); + out->Resize({dims}); + dev_ctx.template Alloc(out); + + T init_value = 0; + if (pooltype == "MAX") { + init_value = static_cast(-FLT_MAX); + } else if (pooltype == "MIN") { + init_value = static_cast(FLT_MAX); + } + phi::funcs::SetConstant setconst; + setconst(dev_ctx, out, static_cast(init_value)); + // the gpu kernel of mean pool record the counts of segment_ids + if (pooltype == "MEAN") { + summed_ids->Resize({dims[0], 1}); + dev_ctx.template Alloc(summed_ids); + setconst(dev_ctx, summed_ids, static_cast(1e-12)); + } + } +#endif + + phi::funcs::SegmentPoolFunctor pool; + + pool(dev_ctx, x, segment_ids, out, summed_ids, pooltype); +} + +template +void SegmentPoolKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& segment_ids, + const std::string& pooltype, + DenseTensor* out, + DenseTensor* summed_ids) { + auto index_type = segment_ids.dtype(); + if (index_type == DataType::INT32) { + SegmentKernelLaunchHelper( + dev_ctx, x, segment_ids, pooltype, out, summed_ids); + } else if (index_type == DataType::INT64) { + SegmentKernelLaunchHelper( + dev_ctx, x, segment_ids, pooltype, out, summed_ids); + } else { + PADDLE_THROW(phi::errors::InvalidArgument( + "Unsupported index type, Expected int, int64, but got %s.", + index_type)); + } +} + +} // namespace phi diff --git a/paddle/phi/kernels/segment_pool_grad_kernel.h b/paddle/phi/kernels/segment_pool_grad_kernel.h new file mode 100644 index 00000000000..e773eed16e8 --- /dev/null +++ b/paddle/phi/kernels/segment_pool_grad_kernel.h @@ -0,0 +1,31 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/phi/core/dense_tensor.h" + +namespace phi { + +template +void SegmentPoolGradKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& segment_ids, + const DenseTensor& out, + paddle::optional summed_ids, + const DenseTensor& out_grad, + const std::string& pooltype, + DenseTensor* x_grad); + +} // namespace phi diff --git a/paddle/phi/kernels/segment_pool_kernel.h b/paddle/phi/kernels/segment_pool_kernel.h new file mode 100644 index 00000000000..8f7b30c2e86 --- /dev/null +++ b/paddle/phi/kernels/segment_pool_kernel.h @@ -0,0 +1,29 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/phi/core/dense_tensor.h" + +namespace phi { + +template +void SegmentPoolKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& segment_ids, + const std::string& pooltype, + DenseTensor* out, + DenseTensor* summed_ids); + +} // namespace phi diff --git a/paddle/phi/ops/compat/segment_pool_sig.cc b/paddle/phi/ops/compat/segment_pool_sig.cc new file mode 100644 index 00000000000..97646a2ac31 --- /dev/null +++ b/paddle/phi/ops/compat/segment_pool_sig.cc @@ -0,0 +1,33 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/core/compat/op_utils.h" + +namespace phi { + +KernelSignature SegmentPoolGradOpArgumentMapping( + const ArgumentMappingContext& ctx) { + return KernelSignature( + "segment_pool_grad", + { + "X", "SegmentIds", "Out", "SummedIds", GradVarName("Out"), + }, + {"pooltype"}, + {GradVarName("X")}); +} + +} // namespace phi + +PD_REGISTER_ARG_MAPPING_FN(segment_pool_grad, + phi::SegmentPoolGradOpArgumentMapping); -- GitLab