diff --git a/paddle/fluid/operators/math/CMakeLists.txt b/paddle/fluid/operators/math/CMakeLists.txt index 42cb92db8625e334401a8baff38e5ec9f8cda521..405661f3a1ab9a7d7477d7072307040c2dc6d78e 100644 --- a/paddle/fluid/operators/math/CMakeLists.txt +++ b/paddle/fluid/operators/math/CMakeLists.txt @@ -15,7 +15,6 @@ math_library(sampler DEPS generator) # math_library(math_function DEPS blas dense_tensor tensor) -math_library(sequence_pooling DEPS math_function jit_kernel_helper) if(WITH_XPU) math_library(beam_search DEPS math_function beam_search_xpu) else() diff --git a/paddle/fluid/operators/sequence_ops/sequence_pool_op.cc b/paddle/fluid/operators/sequence_ops/sequence_pool_op.cc index 938b23a22a63c93427c27c1d4b99eca91cac376c..c44427f98f211b0cf66f72901018bdde2c49b864 100644 --- a/paddle/fluid/operators/sequence_ops/sequence_pool_op.cc +++ b/paddle/fluid/operators/sequence_ops/sequence_pool_op.cc @@ -196,9 +196,6 @@ REGISTER_OPERATOR(sequence_pool, REGISTER_OPERATOR(sequence_pool_grad, ops::SequencePoolGradOp, ops::SequencePoolGradOpNoNeedBufferVarsInferer); -REGISTER_OP_CPU_KERNEL(sequence_pool, - ops::SequencePoolKernel, - ops::SequencePoolKernel); REGISTER_OP_CPU_KERNEL(sequence_pool_grad, ops::SequencePoolGradKernel, diff --git a/paddle/fluid/operators/sequence_ops/sequence_pool_op.cu b/paddle/fluid/operators/sequence_ops/sequence_pool_op.cu index 882ec66f501db0036ba5d2d26bb5e5b0dd9e7dff..df5dde79274f92c28d66c2dd441a0eb8e2896b62 100644 --- a/paddle/fluid/operators/sequence_ops/sequence_pool_op.cu +++ b/paddle/fluid/operators/sequence_ops/sequence_pool_op.cu @@ -14,7 +14,5 @@ limitations under the License. */ #include "paddle/fluid/operators/sequence_ops/sequence_pool_op.h" namespace ops = paddle::operators; -REGISTER_OP_CUDA_KERNEL(sequence_pool, - ops::SequencePoolKernel); REGISTER_OP_CUDA_KERNEL(sequence_pool_grad, ops::SequencePoolGradKernel); diff --git a/paddle/fluid/operators/sequence_ops/sequence_pool_op.h b/paddle/fluid/operators/sequence_ops/sequence_pool_op.h index ddf0d496a77fb51580d28b2f7eb889d90eb2b12a..7bf2d384508a707b7dddf3e042b9db9e324e18ab 100644 --- a/paddle/fluid/operators/sequence_ops/sequence_pool_op.h +++ b/paddle/fluid/operators/sequence_ops/sequence_pool_op.h @@ -17,81 +17,12 @@ limitations under the License. */ #include "paddle/fluid/framework/eigen.h" #include "paddle/fluid/framework/op_registry.h" -#include "paddle/fluid/operators/math/sequence_pooling.h" #include "paddle/phi/kernels/funcs/math_function.h" +#include "paddle/phi/kernels/funcs/sequence_pooling.h" namespace paddle { namespace operators { -template -class SequencePoolKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& context) const override { - auto* in = context.Input("X"); - auto* out = context.Output("Out"); - std::string pooltype = context.Attr("pooltype"); - T pad_value = static_cast(context.Attr("pad_value")); - - auto dims = in->dims(); - auto lod = in->lod(); - auto lod_level = lod.size(); - // InferShape by lod - PADDLE_ENFORCE_GT(lod_level, - 0, - platform::errors::InvalidArgument( - "Input(X) phi::DenseTensor of SequencePoolOp " - "does not contain LoD information.")); - PADDLE_ENFORCE_LE(lod_level, - 2UL, - platform::errors::InvalidArgument( - "The lod level of input shall be no more than 2." - "Received lod level is %d.", - lod_level)); - PADDLE_ENFORCE_GE( - dims[0], - /*batch size = */ static_cast(lod[lod_level - 1].size() - 1), - platform::errors::InvalidArgument( - "The first dimension of Input(X) must be large than batch size." - "But received first dimension of Input(X) is %d, while batch" - "size is %d.", - dims[0], - static_cast(lod[lod_level - 1].size() - 1))); - if (lod_level > 1UL) { - PADDLE_ENFORCE_EQ(lod[0][lod[0].size() - 1], - lod[1].size() - 1, - platform::errors::InvalidArgument( - "The input lod information is illegal.")); - framework::LoD out_lod; - out_lod.push_back(lod[0]); - out->set_lod(out_lod); - } - dims[0] = lod[lod_level - 1].size() - 1; - out->Resize({dims}); - out->mutable_data(context.GetPlace()); - phi::DenseTensor* index = nullptr; - - bool is_test = - context.HasAttr("is_test") ? context.Attr("is_test") : false; - - // Do not create index buffer for inference mode - if (pooltype == "MAX" && - (is_test == false || - platform::is_cpu_place(context.GetPlace()) == false)) { - index = context.Output("MaxIndex"); - index->Resize({dims}); - index->mutable_data(context.GetPlace()); - } - math::SequencePoolFunctor pool; - pool(context.template device_context(), - pooltype, - pad_value, - *in, - out, - is_test, - index); - } -}; - template class SequencePoolGradKernel : public framework::OpKernel { public: @@ -105,7 +36,7 @@ class SequencePoolGradKernel : public framework::OpKernel { index = context.Input("MaxIndex"); } in_g->mutable_data(context.GetPlace()); - math::SequencePoolGradFunctor pool; + phi::funcs::SequencePoolGradFunctor pool; pool(context.template device_context(), pooltype, *out_g, diff --git a/paddle/phi/kernels/CMakeLists.txt b/paddle/phi/kernels/CMakeLists.txt index b0da4df9c8da98745ef837e87ffc7d9c43c74f61..fe8548f9fb5fc3aa3099560e0f5e25d32cdb0387 100644 --- a/paddle/phi/kernels/CMakeLists.txt +++ b/paddle/phi/kernels/CMakeLists.txt @@ -74,6 +74,7 @@ set(COMMON_KERNEL_DEPS phi_dynload_warpctc phi_dynload_warprnnt sequence_padding + sequence_pooling sequence_scale fft phi_data_layout_transform diff --git a/paddle/phi/kernels/cpu/sequence_pool_kernel.cc b/paddle/phi/kernels/cpu/sequence_pool_kernel.cc new file mode 100644 index 0000000000000000000000000000000000000000..4bc0e03983d5bf6f6baaf34aeb999d3e54aebc98 --- /dev/null +++ b/paddle/phi/kernels/cpu/sequence_pool_kernel.cc @@ -0,0 +1,22 @@ +/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/phi/kernels/sequence_pool_kernel.h" + +#include "paddle/phi/backends/cpu/cpu_context.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/impl/sequence_pool_kernel_impl.h" + +PD_REGISTER_KERNEL( + sequence_pool, CPU, ALL_LAYOUT, phi::SequencePoolKernel, float, double) {} diff --git a/paddle/phi/kernels/funcs/CMakeLists.txt b/paddle/phi/kernels/funcs/CMakeLists.txt index 20e97cb887b26716d0b064b85cd38b3a4c1de318..bd1774d756c4bdbd980ca4c68f6fd13e183317fe 100644 --- a/paddle/phi/kernels/funcs/CMakeLists.txt +++ b/paddle/phi/kernels/funcs/CMakeLists.txt @@ -25,6 +25,7 @@ math_library(maxouting) math_library(matrix_bit_code) math_library(sequence_scale) math_library(sequence_padding DEPS lod_utils) +math_library(sequence_pooling DEPS math_function jit_kernel_helper) cc_library( phi_data_layout_transform diff --git a/paddle/fluid/operators/math/sequence_pooling.cc b/paddle/phi/kernels/funcs/sequence_pooling.cc similarity index 91% rename from paddle/fluid/operators/math/sequence_pooling.cc rename to paddle/phi/kernels/funcs/sequence_pooling.cc index 8dbeff2bce1350221ac4f6b326a6de0fab188df5..b530a81ada363a6b98b5dca96f80507358a48c04 100644 --- a/paddle/fluid/operators/math/sequence_pooling.cc +++ b/paddle/phi/kernels/funcs/sequence_pooling.cc @@ -1,4 +1,4 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. +/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "paddle/fluid/operators/math/sequence_pooling.h" +#include "paddle/phi/kernels/funcs/sequence_pooling.h" #include @@ -21,9 +21,8 @@ limitations under the License. */ #include "paddle/phi/kernels/funcs/jit/kernels.h" #include "paddle/phi/kernels/funcs/math_function.h" -namespace paddle { -namespace operators { -namespace math { +namespace phi { +namespace funcs { template dims(); PADDLE_ENFORCE_GT(in_dims.size(), 1, - platform::errors::InvalidArgument( + errors::InvalidArgument( "The rank of input shall be greater than 1, but got " "the rank is %ld. Please check the input value", in_dims.size())); PADDLE_ENFORCE_GT(out_dims.size(), 1, - platform::errors::InvalidArgument( + errors::InvalidArgument( "The rank of output shall be greater than 1, but got " "the rank is %ld. Please check the input value", out_dims.size())); @@ -61,7 +60,7 @@ class MaxSeqPoolFunctor { PADDLE_ENFORCE_EQ( in_dims[i], out_dims[i], - platform::errors::InvalidArgument( + errors::InvalidArgument( "The dimension of input and output shall be same. Expected %ld " "== %ld, but got %ld != %ld. Please check the input value.", in_dims[i], @@ -72,7 +71,7 @@ class MaxSeqPoolFunctor { PADDLE_ENFORCE_EQ( idx_dims, out_dims, - platform::errors::InvalidArgument( + errors::InvalidArgument( "The dimension of index and output shall be same. Expected %ld == " "%ld, but got %ld != %ld. Please check the input value.", idx_dims, @@ -125,13 +124,13 @@ class MaxSeqPoolFunctor { auto out_dims = output->dims(); PADDLE_ENFORCE_GT(in_dims.size(), 1, - platform::errors::InvalidArgument( + errors::InvalidArgument( "The rank of input shall be greater than 1, but got " "%ld <= 1. Please check the input value.", in_dims.size())); PADDLE_ENFORCE_GT(out_dims.size(), 1, - platform::errors::InvalidArgument( + errors::InvalidArgument( "The rank of output shall be greater than 1, but got " "%ld <= 1. Please check the input value.", out_dims.size())); @@ -139,7 +138,7 @@ class MaxSeqPoolFunctor { PADDLE_ENFORCE_EQ( in_dims[i], out_dims[i], - platform::errors::InvalidArgument( + errors::InvalidArgument( "The dimension of input and output shall be same. Expected %ld " "== %ld, but got %ld != %ld. Please check the input value.", in_dims[i], @@ -186,20 +185,20 @@ class MaxSeqPoolGradFunctor { auto idx_dims = index.dims(); PADDLE_ENFORCE_GT(og_dims.size(), 1, - platform::errors::InvalidArgument( + errors::InvalidArgument( "The rank of output@Grad shall be greater than 1, " "but got %ld <= 1. Please check the input value.", og_dims.size())); PADDLE_ENFORCE_GT(ig_dims.size(), 1, - platform::errors::InvalidArgument( + errors::InvalidArgument( "The rank of input@Grad shall be greater than 1, but " "got %ld <= 1. Please check the input value.", ig_dims.size())); for (int64_t i = 1; i < og_dims.size(); ++i) { PADDLE_ENFORCE_EQ(og_dims[i], ig_dims[i], - platform::errors::InvalidArgument( + errors::InvalidArgument( "The dimension of input@Grad and output@Grad shall " "be same. Expected %ld == %ld, but got %ld != %ld. " "Please check the input value.", @@ -211,7 +210,7 @@ class MaxSeqPoolGradFunctor { PADDLE_ENFORCE_EQ( idx_dims, og_dims, - platform::errors::InvalidArgument( + errors::InvalidArgument( "The dimension of index and output@Grad shall be same. Expected " "%ld == %ld, but got %ld != %ld. Please check the input value.", idx_dims, @@ -317,7 +316,7 @@ class SumSeqPoolGradFunctor { int64_t in_w = in_grad->numel() / in_grad->dims()[0]; PADDLE_ENFORCE_EQ(in_w, out_w, - platform::errors::InvalidArgument( + errors::InvalidArgument( "The feature size of input@Grad and output@Grad " "shall be same. Expected %ld == %ld, but got %ld != " "%ld. Please check the input value.", @@ -326,7 +325,7 @@ class SumSeqPoolGradFunctor { in_w, out_w)); const T* out_g_data = out_grad.data(); - T* in_g_data = in_grad->mutable_data(context.GetPlace()); + T* in_g_data = context.template Alloc(in_grad); auto blas = phi::funcs::GetBlas(context); for (int i = 0; i < static_cast(lod.size()) - 1; ++i) { int64_t h = static_cast(lod[i + 1] - lod[i]); @@ -354,21 +353,21 @@ class SequencePoolFunctor { phi::DenseTensor* index = nullptr) { if (pooltype == "MAX") { if (is_test) { - math::MaxSeqPoolFunctor max_pool; + phi::funcs::MaxSeqPoolFunctor max_pool; max_pool(context, input, pad_value, output, index); } else { - math::MaxSeqPoolFunctor max_pool; + phi::funcs::MaxSeqPoolFunctor max_pool; max_pool(context, input, pad_value, output, index); } return; } if (pooltype == "LAST") { - math::LastSeqPoolFunctor last_pool; + phi::funcs::LastSeqPoolFunctor last_pool; last_pool(context, input, pad_value, output); return; } if (pooltype == "FIRST") { - math::FirstSeqPoolFunctor first_pool; + phi::funcs::FirstSeqPoolFunctor first_pool; first_pool(context, input, pad_value, output); return; } @@ -377,17 +376,17 @@ class SequencePoolFunctor { if (pooltype == "SUM") { auto place = context.GetPlace(); PADDLE_ENFORCE_EQ( - platform::is_cpu_place(place), + place == phi::CPUPlace(), true, - platform::errors::InvalidArgument( + errors::InvalidArgument( "Sequence_pool should run on CPU Device when pooltype is SUM")); const T* src = input.data(); - T* dst = output->mutable_data(place); + T* dst = context.template Alloc(output); phi::jit::seq_pool_attr_t attr( static_cast(input.numel() / input.dims()[0]), phi::jit::SeqPoolType::kSum); auto seqpool = phi::jit::KernelFuncs, - platform::CPUPlace>::Cache() + phi::CPUPlace>::Cache() .At(attr); for (int i = 0; i < static_cast(lod.size()) - 1; ++i) { attr.h = static_cast(lod[i + 1] - lod[i]); @@ -424,7 +423,7 @@ class SequencePoolFunctor { out_e.device(place) = in_e.sum(Eigen::array({{0}})) / std::sqrt(static_cast(h)); } else { - PADDLE_THROW(platform::errors::InvalidArgument( + PADDLE_THROW(errors::InvalidArgument( "unsupported pooling pooltype: %s. Only support \"AVERAGE\" and " "\"SQRT\"", pooltype)); @@ -443,7 +442,7 @@ class SequencePoolGradFunctor { /* max pool has index */ const phi::DenseTensor* index = nullptr) { if (pooltype == "MAX") { - math::MaxSeqPoolGradFunctor max_pool_grad; + phi::funcs::MaxSeqPoolGradFunctor max_pool_grad; max_pool_grad(context, out_grad, *index, in_grad); return; } @@ -455,7 +454,7 @@ class SequencePoolGradFunctor { } if (pooltype == "SUM") { - math::SumSeqPoolGradFunctor sum_pool_grad; + phi::funcs::SumSeqPoolGradFunctor sum_pool_grad; sum_pool_grad(context, out_grad, in_grad); return; } @@ -485,7 +484,7 @@ class SequencePoolGradFunctor { } else if (pooltype == "FIRST") { in_g_e.chip(0, 0).device(place) = out_g_e_v; } else { - PADDLE_THROW(platform::errors::InvalidArgument( + PADDLE_THROW(errors::InvalidArgument( "unsupported pooling pooltype: %s. Only support \"AVERAGE\", " "\"SQRT\", \"LAST\" and \"FIRST\"", pooltype)); @@ -499,6 +498,5 @@ template class SequencePoolFunctor; template class SequencePoolGradFunctor; template class SequencePoolGradFunctor; -} // namespace math -} // namespace operators -} // namespace paddle +} // namespace funcs +} // namespace phi diff --git a/paddle/fluid/operators/math/sequence_pooling.cu b/paddle/phi/kernels/funcs/sequence_pooling.cu similarity index 93% rename from paddle/fluid/operators/math/sequence_pooling.cu rename to paddle/phi/kernels/funcs/sequence_pooling.cu index e56f0025a0e664c1c093adcbed30184d4c32ea33..4bc4b11692d5c4fe5251904373b91e804da60436 100644 --- a/paddle/fluid/operators/math/sequence_pooling.cu +++ b/paddle/phi/kernels/funcs/sequence_pooling.cu @@ -1,4 +1,4 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. +/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -15,14 +15,14 @@ limitations under the License. */ #include #include -#include "paddle/fluid/operators/math/sequence_pooling.h" -#include "paddle/fluid/platform/macros.h" #include "paddle/phi/backends/gpu/gpu_primitives.h" +#include "paddle/phi/core/macros.h" +#include "paddle/phi/core/mixed_vector.h" #include "paddle/phi/kernels/funcs/math_function.h" +#include "paddle/phi/kernels/funcs/sequence_pooling.h" -namespace paddle { -namespace operators { -namespace math { +namespace phi { +namespace funcs { template struct MaxPoolFunctor { @@ -213,7 +213,7 @@ class SequencePoolFunctor { mix_vector.CUDAData(context.GetPlace()), lod.size(), item_dim, - output->mutable_data(context.GetPlace()), + context.template Alloc(output), index->data()); } else if (pooltype == "AVERAGE") { sequence_pool_kernel> @@ -224,7 +224,7 @@ class SequencePoolFunctor { mix_vector.CUDAData(context.GetPlace()), lod.size(), item_dim, - output->mutable_data(context.GetPlace()), + context.template Alloc(output), nullptr); } else if (pooltype == "SUM") { sequence_pool_kernel> @@ -235,7 +235,7 @@ class SequencePoolFunctor { mix_vector.CUDAData(context.GetPlace()), lod.size(), item_dim, - output->mutable_data(context.GetPlace()), + context.template Alloc(output), nullptr); } else if (pooltype == "SQRT") { sequence_pool_kernel> @@ -246,7 +246,7 @@ class SequencePoolFunctor { mix_vector.CUDAData(context.GetPlace()), lod.size(), item_dim, - output->mutable_data(context.GetPlace()), + context.template Alloc(output), nullptr); } else if (pooltype == "LAST") { sequence_pool_kernel> @@ -257,7 +257,7 @@ class SequencePoolFunctor { mix_vector.CUDAData(context.GetPlace()), lod.size(), item_dim, - output->mutable_data(context.GetPlace()), + context.template Alloc(output), nullptr); } else if (pooltype == "FIRST") { sequence_pool_kernel> @@ -268,10 +268,10 @@ class SequencePoolFunctor { mix_vector.CUDAData(context.GetPlace()), lod.size(), item_dim, - output->mutable_data(context.GetPlace()), + context.template Alloc(output), nullptr); } else { - PADDLE_THROW(platform::errors::InvalidArgument( + PADDLE_THROW(errors::InvalidArgument( "unsupported pooling pooltype: %s. Only support \"MAX\", " "\"AVERAGE\", \"SUM\", \"SQRT\", \"LAST\" and \"FIRST\"", pooltype)); @@ -430,7 +430,7 @@ class SequencePoolGradFunctor { mix_vector.CUDAData(context.GetPlace()), lod.size(), item_dim, - in_grad->mutable_data(context.GetPlace()), + context.template Alloc(in_grad), index->data()); } else if (pooltype == "AVERAGE") { sequence_pool_grad_kernel> @@ -440,7 +440,7 @@ class SequencePoolGradFunctor { mix_vector.CUDAData(context.GetPlace()), lod.size(), item_dim, - in_grad->mutable_data(context.GetPlace()), + context.template Alloc(in_grad), nullptr); } else if (pooltype == "SUM") { sequence_pool_grad_kernel> @@ -450,7 +450,7 @@ class SequencePoolGradFunctor { mix_vector.CUDAData(context.GetPlace()), lod.size(), item_dim, - in_grad->mutable_data(context.GetPlace()), + context.template Alloc(in_grad), nullptr); } else if (pooltype == "SQRT") { sequence_pool_grad_kernel> @@ -460,7 +460,7 @@ class SequencePoolGradFunctor { mix_vector.CUDAData(context.GetPlace()), lod.size(), item_dim, - in_grad->mutable_data(context.GetPlace()), + context.template Alloc(in_grad), nullptr); } else if (pooltype == "LAST") { sequence_pool_grad_kernel> @@ -470,7 +470,7 @@ class SequencePoolGradFunctor { mix_vector.CUDAData(context.GetPlace()), lod.size(), item_dim, - in_grad->mutable_data(context.GetPlace()), + context.template Alloc(in_grad), nullptr); } else if (pooltype == "FIRST") { sequence_pool_grad_kernel> @@ -480,11 +480,11 @@ class SequencePoolGradFunctor { mix_vector.CUDAData(context.GetPlace()), lod.size(), item_dim, - in_grad->mutable_data(context.GetPlace()), + context.template Alloc(in_grad), nullptr); } else { - PADDLE_THROW(platform::errors::InvalidArgument( + PADDLE_THROW(errors::InvalidArgument( "unsupported pooling pooltype: %s. Only support \"MAX\", " "\"AVERAGE\", \"SUM\", \"SQRT\", \"LAST\" and \"FIRST\"", pooltype)); @@ -498,6 +498,5 @@ template class SequencePoolFunctor; template class SequencePoolGradFunctor; template class SequencePoolGradFunctor; -} // namespace math -} // namespace operators -} // namespace paddle +} // namespace funcs +} // namespace phi diff --git a/paddle/fluid/operators/math/sequence_pooling.h b/paddle/phi/kernels/funcs/sequence_pooling.h similarity index 80% rename from paddle/fluid/operators/math/sequence_pooling.h rename to paddle/phi/kernels/funcs/sequence_pooling.h index 6a8e943d5d8341d69a530dabbfb55460beae15ea..8602d5e4cfc0011c3db6cdb1356500cdda2ff2e9 100644 --- a/paddle/fluid/operators/math/sequence_pooling.h +++ b/paddle/phi/kernels/funcs/sequence_pooling.h @@ -1,4 +1,4 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. +/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -15,13 +15,10 @@ limitations under the License. */ #pragma once #include -#include "paddle/fluid/framework/lod_tensor.h" -#include "paddle/fluid/framework/tensor.h" -#include "paddle/fluid/platform/device_context.h" +#include "paddle/phi/core/dense_tensor.h" -namespace paddle { -namespace operators { -namespace math { +namespace phi { +namespace funcs { template class SequencePoolFunctor { @@ -47,6 +44,5 @@ class SequencePoolGradFunctor { const phi::DenseTensor* index = nullptr); }; -} // namespace math -} // namespace operators -} // namespace paddle +} // namespace funcs +} // namespace phi diff --git a/paddle/phi/kernels/gpu/sequence_pool_kernel.cu b/paddle/phi/kernels/gpu/sequence_pool_kernel.cu new file mode 100644 index 0000000000000000000000000000000000000000..7baf83e75dc3dabc11ea2a14f1760d849c17534c --- /dev/null +++ b/paddle/phi/kernels/gpu/sequence_pool_kernel.cu @@ -0,0 +1,22 @@ +/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/phi/kernels/sequence_pool_kernel.h" +#include "paddle/phi/kernels/impl/sequence_pool_kernel_impl.h" + +#include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/core/kernel_registry.h" + +PD_REGISTER_KERNEL( + sequence_pool, GPU, ALL_LAYOUT, phi::SequencePoolKernel, float) {} diff --git a/paddle/phi/kernels/impl/sequence_pool_kernel_impl.h b/paddle/phi/kernels/impl/sequence_pool_kernel_impl.h new file mode 100644 index 0000000000000000000000000000000000000000..56633b6b54270e245647041c0c70ef14a0a7a014 --- /dev/null +++ b/paddle/phi/kernels/impl/sequence_pool_kernel_impl.h @@ -0,0 +1,77 @@ +/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include "paddle/phi/kernels/funcs/sequence_pooling.h" + +namespace phi { + +template +void SequencePoolKernel(const Context& ctx, + const DenseTensor& x, + bool is_test, + const std::string& pooltype, + float pad_value, + DenseTensor* out, + DenseTensor* max_index) { + T pad_value_ = static_cast(pad_value); + + auto dims = x.dims(); + auto lod = x.lod(); + auto lod_level = lod.size(); + // InferShape by lod + PADDLE_ENFORCE_GT( + lod_level, + 0, + errors::InvalidArgument("Input(X) phi::DenseTensor of SequencePoolOp " + "does not contain LoD information.")); + PADDLE_ENFORCE_LE( + lod_level, + 2UL, + errors::InvalidArgument("The lod level of input shall be no more than 2." + "Received lod level is %d.", + lod_level)); + PADDLE_ENFORCE_GE( + dims[0], + /*batch size = */ static_cast(lod[lod_level - 1].size() - 1), + errors::InvalidArgument( + "The first dimension of Input(X) must be large than batch size." + "But received first dimension of Input(X) is %d, while batch" + "size is %d.", + dims[0], + static_cast(lod[lod_level - 1].size() - 1))); + if (lod_level > 1UL) { + PADDLE_ENFORCE_EQ( + lod[0][lod[0].size() - 1], + lod[1].size() - 1, + errors::InvalidArgument("The input lod information is illegal.")); + phi::LoD out_lod; + out_lod.push_back(lod[0]); + out->set_lod(out_lod); + } + dims[0] = lod[lod_level - 1].size() - 1; + out->Resize({dims}); + ctx.template Alloc(out); + phi::DenseTensor* index = nullptr; + + // Do not create index buffer for inference mode + if (pooltype == "MAX" && + (is_test == false || (ctx.GetPlace() == phi::CPUPlace()) == false)) { + index = max_index; + index->Resize({dims}); + ctx.template Alloc(index); + } + phi::funcs::SequencePoolFunctor pool; + pool(ctx, pooltype, pad_value_, x, out, is_test, index); +} + +} // namespace phi diff --git a/paddle/phi/kernels/sequence_pool_kernel.h b/paddle/phi/kernels/sequence_pool_kernel.h new file mode 100644 index 0000000000000000000000000000000000000000..f423d685a66524d51917d25b2d84e8a1d3cd0103 --- /dev/null +++ b/paddle/phi/kernels/sequence_pool_kernel.h @@ -0,0 +1,26 @@ +/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include "paddle/phi/core/dense_tensor.h" + +namespace phi { +template +void SequencePoolKernel(const Context& ctx, + const DenseTensor& x, + bool is_test, + const std::string& pooltype, + float pad_value, + DenseTensor* out, + DenseTensor* max_index); + +} // namespace phi diff --git a/paddle/phi/ops/compat/sequence_pool_sig.cc b/paddle/phi/ops/compat/sequence_pool_sig.cc new file mode 100644 index 0000000000000000000000000000000000000000..6c4d6691db4bb90f9fc10341c305ade8992aa2e8 --- /dev/null +++ b/paddle/phi/ops/compat/sequence_pool_sig.cc @@ -0,0 +1,26 @@ +/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/phi/core/compat/op_utils.h" + +namespace phi { + +KernelSignature SequencePoolOpArgumentMapping( + const ArgumentMappingContext& ctx) { + return KernelSignature("sequence_pool", + {"X"}, + {"is_test", "pooltype", "pad_value"}, + {"Out", "MaxIndex"}); +} + +} // namespace phi + +PD_REGISTER_ARG_MAPPING_FN(sequence_pool, phi::SequencePoolOpArgumentMapping); diff --git a/test/cpp/phi/kernels/sequence_pooling_test.cc b/test/cpp/phi/kernels/sequence_pooling_test.cc index 3c12d55ed360f93167b3a33b49dc9ce3ad811d0f..b9a6bda19a2dfa89b4cc30eb578620da2651f166 100644 --- a/test/cpp/phi/kernels/sequence_pooling_test.cc +++ b/test/cpp/phi/kernels/sequence_pooling_test.cc @@ -14,11 +14,10 @@ limitations under the License. */ #include -#include "paddle/fluid/operators/math/sequence_pooling.h" - #include "paddle/phi/backends/context_pool.h" #include "paddle/phi/common/place.h" #include "paddle/phi/core/tensor_utils.h" +#include "paddle/phi/kernels/funcs/sequence_pooling.h" template void TestSequencePoolingSum(const DeviceContext &context, @@ -78,7 +77,7 @@ void TestSequencePoolingSum(const DeviceContext &context, } // call functor - paddle::operators::math::SequencePoolGradFunctor()( + phi::funcs::SequencePoolGradFunctor()( context, "SUM", out_grad, &in_grad); if (place == phi::CPUPlace()) {