未验证 提交 b281b221 编写于 作者: G gouzil 提交者: GitHub

[phi] move sequence_pool to phi - Step 2 : sequence_pool_op (#52750)

* [phi] move sequence_pool kernel to phi

* [phi] mv sequence_pooling to phi funcs

* [phi] mv sequence_pooling_test

* [phi] RollBACK `paddle/fluid/operators/sequence_ops/sequence_pool_op.cc`

* [phi][funcs] fix mutable_data

* [phi][funcs] fix mutable_data
上级 3bac6264
......@@ -15,7 +15,6 @@ math_library(sampler DEPS generator)
# math_library(math_function DEPS blas dense_tensor tensor)
math_library(sequence_pooling DEPS math_function jit_kernel_helper)
if(WITH_XPU)
math_library(beam_search DEPS math_function beam_search_xpu)
else()
......
......@@ -196,9 +196,6 @@ REGISTER_OPERATOR(sequence_pool,
REGISTER_OPERATOR(sequence_pool_grad,
ops::SequencePoolGradOp,
ops::SequencePoolGradOpNoNeedBufferVarsInferer);
REGISTER_OP_CPU_KERNEL(sequence_pool,
ops::SequencePoolKernel<phi::CPUContext, float>,
ops::SequencePoolKernel<phi::CPUContext, double>);
REGISTER_OP_CPU_KERNEL(sequence_pool_grad,
ops::SequencePoolGradKernel<phi::CPUContext, float>,
......
......@@ -14,7 +14,5 @@ limitations under the License. */
#include "paddle/fluid/operators/sequence_ops/sequence_pool_op.h"
namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(sequence_pool,
ops::SequencePoolKernel<phi::GPUContext, float>);
REGISTER_OP_CUDA_KERNEL(sequence_pool_grad,
ops::SequencePoolGradKernel<phi::GPUContext, float>);
......@@ -17,81 +17,12 @@ limitations under the License. */
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/math/sequence_pooling.h"
#include "paddle/phi/kernels/funcs/math_function.h"
#include "paddle/phi/kernels/funcs/sequence_pooling.h"
namespace paddle {
namespace operators {
template <typename DeviceContext, typename T>
class SequencePoolKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* in = context.Input<phi::DenseTensor>("X");
auto* out = context.Output<phi::DenseTensor>("Out");
std::string pooltype = context.Attr<std::string>("pooltype");
T pad_value = static_cast<T>(context.Attr<float>("pad_value"));
auto dims = in->dims();
auto lod = in->lod();
auto lod_level = lod.size();
// InferShape by lod
PADDLE_ENFORCE_GT(lod_level,
0,
platform::errors::InvalidArgument(
"Input(X) phi::DenseTensor of SequencePoolOp "
"does not contain LoD information."));
PADDLE_ENFORCE_LE(lod_level,
2UL,
platform::errors::InvalidArgument(
"The lod level of input shall be no more than 2."
"Received lod level is %d.",
lod_level));
PADDLE_ENFORCE_GE(
dims[0],
/*batch size = */ static_cast<int64_t>(lod[lod_level - 1].size() - 1),
platform::errors::InvalidArgument(
"The first dimension of Input(X) must be large than batch size."
"But received first dimension of Input(X) is %d, while batch"
"size is %d.",
dims[0],
static_cast<int64_t>(lod[lod_level - 1].size() - 1)));
if (lod_level > 1UL) {
PADDLE_ENFORCE_EQ(lod[0][lod[0].size() - 1],
lod[1].size() - 1,
platform::errors::InvalidArgument(
"The input lod information is illegal."));
framework::LoD out_lod;
out_lod.push_back(lod[0]);
out->set_lod(out_lod);
}
dims[0] = lod[lod_level - 1].size() - 1;
out->Resize({dims});
out->mutable_data<T>(context.GetPlace());
phi::DenseTensor* index = nullptr;
bool is_test =
context.HasAttr("is_test") ? context.Attr<bool>("is_test") : false;
// Do not create index buffer for inference mode
if (pooltype == "MAX" &&
(is_test == false ||
platform::is_cpu_place(context.GetPlace()) == false)) {
index = context.Output<phi::DenseTensor>("MaxIndex");
index->Resize({dims});
index->mutable_data<int>(context.GetPlace());
}
math::SequencePoolFunctor<DeviceContext, T> pool;
pool(context.template device_context<DeviceContext>(),
pooltype,
pad_value,
*in,
out,
is_test,
index);
}
};
template <typename DeviceContext, typename T>
class SequencePoolGradKernel : public framework::OpKernel<T> {
public:
......@@ -105,7 +36,7 @@ class SequencePoolGradKernel : public framework::OpKernel<T> {
index = context.Input<phi::DenseTensor>("MaxIndex");
}
in_g->mutable_data<T>(context.GetPlace());
math::SequencePoolGradFunctor<DeviceContext, T> pool;
phi::funcs::SequencePoolGradFunctor<DeviceContext, T> pool;
pool(context.template device_context<DeviceContext>(),
pooltype,
*out_g,
......
......@@ -74,6 +74,7 @@ set(COMMON_KERNEL_DEPS
phi_dynload_warpctc
phi_dynload_warprnnt
sequence_padding
sequence_pooling
sequence_scale
fft
phi_data_layout_transform
......
/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/phi/kernels/sequence_pool_kernel.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/sequence_pool_kernel_impl.h"
PD_REGISTER_KERNEL(
sequence_pool, CPU, ALL_LAYOUT, phi::SequencePoolKernel, float, double) {}
......@@ -25,6 +25,7 @@ math_library(maxouting)
math_library(matrix_bit_code)
math_library(sequence_scale)
math_library(sequence_padding DEPS lod_utils)
math_library(sequence_pooling DEPS math_function jit_kernel_helper)
cc_library(
phi_data_layout_transform
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
......@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/math/sequence_pooling.h"
#include "paddle/phi/kernels/funcs/sequence_pooling.h"
#include <string>
......@@ -21,9 +21,8 @@ limitations under the License. */
#include "paddle/phi/kernels/funcs/jit/kernels.h"
#include "paddle/phi/kernels/funcs/math_function.h"
namespace paddle {
namespace operators {
namespace math {
namespace phi {
namespace funcs {
template <typename T,
int MajorType = Eigen::RowMajor,
......@@ -47,13 +46,13 @@ class MaxSeqPoolFunctor {
auto idx_dims = index->dims();
PADDLE_ENFORCE_GT(in_dims.size(),
1,
platform::errors::InvalidArgument(
errors::InvalidArgument(
"The rank of input shall be greater than 1, but got "
"the rank is %ld. Please check the input value",
in_dims.size()));
PADDLE_ENFORCE_GT(out_dims.size(),
1,
platform::errors::InvalidArgument(
errors::InvalidArgument(
"The rank of output shall be greater than 1, but got "
"the rank is %ld. Please check the input value",
out_dims.size()));
......@@ -61,7 +60,7 @@ class MaxSeqPoolFunctor {
PADDLE_ENFORCE_EQ(
in_dims[i],
out_dims[i],
platform::errors::InvalidArgument(
errors::InvalidArgument(
"The dimension of input and output shall be same. Expected %ld "
"== %ld, but got %ld != %ld. Please check the input value.",
in_dims[i],
......@@ -72,7 +71,7 @@ class MaxSeqPoolFunctor {
PADDLE_ENFORCE_EQ(
idx_dims,
out_dims,
platform::errors::InvalidArgument(
errors::InvalidArgument(
"The dimension of index and output shall be same. Expected %ld == "
"%ld, but got %ld != %ld. Please check the input value.",
idx_dims,
......@@ -125,13 +124,13 @@ class MaxSeqPoolFunctor<T, true> {
auto out_dims = output->dims();
PADDLE_ENFORCE_GT(in_dims.size(),
1,
platform::errors::InvalidArgument(
errors::InvalidArgument(
"The rank of input shall be greater than 1, but got "
"%ld <= 1. Please check the input value.",
in_dims.size()));
PADDLE_ENFORCE_GT(out_dims.size(),
1,
platform::errors::InvalidArgument(
errors::InvalidArgument(
"The rank of output shall be greater than 1, but got "
"%ld <= 1. Please check the input value.",
out_dims.size()));
......@@ -139,7 +138,7 @@ class MaxSeqPoolFunctor<T, true> {
PADDLE_ENFORCE_EQ(
in_dims[i],
out_dims[i],
platform::errors::InvalidArgument(
errors::InvalidArgument(
"The dimension of input and output shall be same. Expected %ld "
"== %ld, but got %ld != %ld. Please check the input value.",
in_dims[i],
......@@ -186,20 +185,20 @@ class MaxSeqPoolGradFunctor {
auto idx_dims = index.dims();
PADDLE_ENFORCE_GT(og_dims.size(),
1,
platform::errors::InvalidArgument(
errors::InvalidArgument(
"The rank of output@Grad shall be greater than 1, "
"but got %ld <= 1. Please check the input value.",
og_dims.size()));
PADDLE_ENFORCE_GT(ig_dims.size(),
1,
platform::errors::InvalidArgument(
errors::InvalidArgument(
"The rank of input@Grad shall be greater than 1, but "
"got %ld <= 1. Please check the input value.",
ig_dims.size()));
for (int64_t i = 1; i < og_dims.size(); ++i) {
PADDLE_ENFORCE_EQ(og_dims[i],
ig_dims[i],
platform::errors::InvalidArgument(
errors::InvalidArgument(
"The dimension of input@Grad and output@Grad shall "
"be same. Expected %ld == %ld, but got %ld != %ld. "
"Please check the input value.",
......@@ -211,7 +210,7 @@ class MaxSeqPoolGradFunctor {
PADDLE_ENFORCE_EQ(
idx_dims,
og_dims,
platform::errors::InvalidArgument(
errors::InvalidArgument(
"The dimension of index and output@Grad shall be same. Expected "
"%ld == %ld, but got %ld != %ld. Please check the input value.",
idx_dims,
......@@ -317,7 +316,7 @@ class SumSeqPoolGradFunctor {
int64_t in_w = in_grad->numel() / in_grad->dims()[0];
PADDLE_ENFORCE_EQ(in_w,
out_w,
platform::errors::InvalidArgument(
errors::InvalidArgument(
"The feature size of input@Grad and output@Grad "
"shall be same. Expected %ld == %ld, but got %ld != "
"%ld. Please check the input value.",
......@@ -326,7 +325,7 @@ class SumSeqPoolGradFunctor {
in_w,
out_w));
const T* out_g_data = out_grad.data<T>();
T* in_g_data = in_grad->mutable_data<T>(context.GetPlace());
T* in_g_data = context.template Alloc<T>(in_grad);
auto blas = phi::funcs::GetBlas<phi::CPUContext, T>(context);
for (int i = 0; i < static_cast<int>(lod.size()) - 1; ++i) {
int64_t h = static_cast<int64_t>(lod[i + 1] - lod[i]);
......@@ -354,21 +353,21 @@ class SequencePoolFunctor<phi::CPUContext, T> {
phi::DenseTensor* index = nullptr) {
if (pooltype == "MAX") {
if (is_test) {
math::MaxSeqPoolFunctor<T, true> max_pool;
phi::funcs::MaxSeqPoolFunctor<T, true> max_pool;
max_pool(context, input, pad_value, output, index);
} else {
math::MaxSeqPoolFunctor<T, false> max_pool;
phi::funcs::MaxSeqPoolFunctor<T, false> max_pool;
max_pool(context, input, pad_value, output, index);
}
return;
}
if (pooltype == "LAST") {
math::LastSeqPoolFunctor<T> last_pool;
phi::funcs::LastSeqPoolFunctor<T> last_pool;
last_pool(context, input, pad_value, output);
return;
}
if (pooltype == "FIRST") {
math::FirstSeqPoolFunctor<T> first_pool;
phi::funcs::FirstSeqPoolFunctor<T> first_pool;
first_pool(context, input, pad_value, output);
return;
}
......@@ -377,17 +376,17 @@ class SequencePoolFunctor<phi::CPUContext, T> {
if (pooltype == "SUM") {
auto place = context.GetPlace();
PADDLE_ENFORCE_EQ(
platform::is_cpu_place(place),
place == phi::CPUPlace(),
true,
platform::errors::InvalidArgument(
errors::InvalidArgument(
"Sequence_pool should run on CPU Device when pooltype is SUM"));
const T* src = input.data<T>();
T* dst = output->mutable_data<T>(place);
T* dst = context.template Alloc<T>(output);
phi::jit::seq_pool_attr_t attr(
static_cast<int>(input.numel() / input.dims()[0]),
phi::jit::SeqPoolType::kSum);
auto seqpool = phi::jit::KernelFuncs<phi::jit::SeqPoolTuple<T>,
platform::CPUPlace>::Cache()
phi::CPUPlace>::Cache()
.At(attr);
for (int i = 0; i < static_cast<int>(lod.size()) - 1; ++i) {
attr.h = static_cast<int>(lod[i + 1] - lod[i]);
......@@ -424,7 +423,7 @@ class SequencePoolFunctor<phi::CPUContext, T> {
out_e.device(place) = in_e.sum(Eigen::array<int, 1>({{0}})) /
std::sqrt(static_cast<T>(h));
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
PADDLE_THROW(errors::InvalidArgument(
"unsupported pooling pooltype: %s. Only support \"AVERAGE\" and "
"\"SQRT\"",
pooltype));
......@@ -443,7 +442,7 @@ class SequencePoolGradFunctor<phi::CPUContext, T> {
/* max pool has index */
const phi::DenseTensor* index = nullptr) {
if (pooltype == "MAX") {
math::MaxSeqPoolGradFunctor<T> max_pool_grad;
phi::funcs::MaxSeqPoolGradFunctor<T> max_pool_grad;
max_pool_grad(context, out_grad, *index, in_grad);
return;
}
......@@ -455,7 +454,7 @@ class SequencePoolGradFunctor<phi::CPUContext, T> {
}
if (pooltype == "SUM") {
math::SumSeqPoolGradFunctor<T> sum_pool_grad;
phi::funcs::SumSeqPoolGradFunctor<T> sum_pool_grad;
sum_pool_grad(context, out_grad, in_grad);
return;
}
......@@ -485,7 +484,7 @@ class SequencePoolGradFunctor<phi::CPUContext, T> {
} else if (pooltype == "FIRST") {
in_g_e.chip(0, 0).device(place) = out_g_e_v;
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
PADDLE_THROW(errors::InvalidArgument(
"unsupported pooling pooltype: %s. Only support \"AVERAGE\", "
"\"SQRT\", \"LAST\" and \"FIRST\"",
pooltype));
......@@ -499,6 +498,5 @@ template class SequencePoolFunctor<phi::CPUContext, double>;
template class SequencePoolGradFunctor<phi::CPUContext, float>;
template class SequencePoolGradFunctor<phi::CPUContext, double>;
} // namespace math
} // namespace operators
} // namespace paddle
} // namespace funcs
} // namespace phi
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
......@@ -15,14 +15,14 @@ limitations under the License. */
#include <algorithm>
#include <string>
#include "paddle/fluid/operators/math/sequence_pooling.h"
#include "paddle/fluid/platform/macros.h"
#include "paddle/phi/backends/gpu/gpu_primitives.h"
#include "paddle/phi/core/macros.h"
#include "paddle/phi/core/mixed_vector.h"
#include "paddle/phi/kernels/funcs/math_function.h"
#include "paddle/phi/kernels/funcs/sequence_pooling.h"
namespace paddle {
namespace operators {
namespace math {
namespace phi {
namespace funcs {
template <typename T>
struct MaxPoolFunctor {
......@@ -213,7 +213,7 @@ class SequencePoolFunctor<phi::GPUContext, T> {
mix_vector.CUDAData(context.GetPlace()),
lod.size(),
item_dim,
output->mutable_data<T>(context.GetPlace()),
context.template Alloc<T>(output),
index->data<int>());
} else if (pooltype == "AVERAGE") {
sequence_pool_kernel<T, AvgPoolFunctor<T>>
......@@ -224,7 +224,7 @@ class SequencePoolFunctor<phi::GPUContext, T> {
mix_vector.CUDAData(context.GetPlace()),
lod.size(),
item_dim,
output->mutable_data<T>(context.GetPlace()),
context.template Alloc<T>(output),
nullptr);
} else if (pooltype == "SUM") {
sequence_pool_kernel<T, SumPoolFunctor<T>>
......@@ -235,7 +235,7 @@ class SequencePoolFunctor<phi::GPUContext, T> {
mix_vector.CUDAData(context.GetPlace()),
lod.size(),
item_dim,
output->mutable_data<T>(context.GetPlace()),
context.template Alloc<T>(output),
nullptr);
} else if (pooltype == "SQRT") {
sequence_pool_kernel<T, SqrtPoolFunctor<T>>
......@@ -246,7 +246,7 @@ class SequencePoolFunctor<phi::GPUContext, T> {
mix_vector.CUDAData(context.GetPlace()),
lod.size(),
item_dim,
output->mutable_data<T>(context.GetPlace()),
context.template Alloc<T>(output),
nullptr);
} else if (pooltype == "LAST") {
sequence_pool_kernel<T, LastPoolFunctor<T>>
......@@ -257,7 +257,7 @@ class SequencePoolFunctor<phi::GPUContext, T> {
mix_vector.CUDAData(context.GetPlace()),
lod.size(),
item_dim,
output->mutable_data<T>(context.GetPlace()),
context.template Alloc<T>(output),
nullptr);
} else if (pooltype == "FIRST") {
sequence_pool_kernel<T, FirstPoolFunctor<T>>
......@@ -268,10 +268,10 @@ class SequencePoolFunctor<phi::GPUContext, T> {
mix_vector.CUDAData(context.GetPlace()),
lod.size(),
item_dim,
output->mutable_data<T>(context.GetPlace()),
context.template Alloc<T>(output),
nullptr);
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
PADDLE_THROW(errors::InvalidArgument(
"unsupported pooling pooltype: %s. Only support \"MAX\", "
"\"AVERAGE\", \"SUM\", \"SQRT\", \"LAST\" and \"FIRST\"",
pooltype));
......@@ -430,7 +430,7 @@ class SequencePoolGradFunctor<phi::GPUContext, T> {
mix_vector.CUDAData(context.GetPlace()),
lod.size(),
item_dim,
in_grad->mutable_data<T>(context.GetPlace()),
context.template Alloc<T>(in_grad),
index->data<int>());
} else if (pooltype == "AVERAGE") {
sequence_pool_grad_kernel<T, AvgPoolGradFunctor<T>>
......@@ -440,7 +440,7 @@ class SequencePoolGradFunctor<phi::GPUContext, T> {
mix_vector.CUDAData(context.GetPlace()),
lod.size(),
item_dim,
in_grad->mutable_data<T>(context.GetPlace()),
context.template Alloc<T>(in_grad),
nullptr);
} else if (pooltype == "SUM") {
sequence_pool_grad_kernel<T, SumPoolGradFunctor<T>>
......@@ -450,7 +450,7 @@ class SequencePoolGradFunctor<phi::GPUContext, T> {
mix_vector.CUDAData(context.GetPlace()),
lod.size(),
item_dim,
in_grad->mutable_data<T>(context.GetPlace()),
context.template Alloc<T>(in_grad),
nullptr);
} else if (pooltype == "SQRT") {
sequence_pool_grad_kernel<T, SqrtPoolGradFunctor<T>>
......@@ -460,7 +460,7 @@ class SequencePoolGradFunctor<phi::GPUContext, T> {
mix_vector.CUDAData(context.GetPlace()),
lod.size(),
item_dim,
in_grad->mutable_data<T>(context.GetPlace()),
context.template Alloc<T>(in_grad),
nullptr);
} else if (pooltype == "LAST") {
sequence_pool_grad_kernel<T, LastPoolGradFunctor<T>>
......@@ -470,7 +470,7 @@ class SequencePoolGradFunctor<phi::GPUContext, T> {
mix_vector.CUDAData(context.GetPlace()),
lod.size(),
item_dim,
in_grad->mutable_data<T>(context.GetPlace()),
context.template Alloc<T>(in_grad),
nullptr);
} else if (pooltype == "FIRST") {
sequence_pool_grad_kernel<T, FirstPoolGradFunctor<T>>
......@@ -480,11 +480,11 @@ class SequencePoolGradFunctor<phi::GPUContext, T> {
mix_vector.CUDAData(context.GetPlace()),
lod.size(),
item_dim,
in_grad->mutable_data<T>(context.GetPlace()),
context.template Alloc<T>(in_grad),
nullptr);
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
PADDLE_THROW(errors::InvalidArgument(
"unsupported pooling pooltype: %s. Only support \"MAX\", "
"\"AVERAGE\", \"SUM\", \"SQRT\", \"LAST\" and \"FIRST\"",
pooltype));
......@@ -498,6 +498,5 @@ template class SequencePoolFunctor<phi::GPUContext, double>;
template class SequencePoolGradFunctor<phi::GPUContext, float>;
template class SequencePoolGradFunctor<phi::GPUContext, double>;
} // namespace math
} // namespace operators
} // namespace paddle
} // namespace funcs
} // namespace phi
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
......@@ -15,13 +15,10 @@ limitations under the License. */
#pragma once
#include <string>
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/phi/core/dense_tensor.h"
namespace paddle {
namespace operators {
namespace math {
namespace phi {
namespace funcs {
template <typename DeviceContext, typename T>
class SequencePoolFunctor {
......@@ -47,6 +44,5 @@ class SequencePoolGradFunctor {
const phi::DenseTensor* index = nullptr);
};
} // namespace math
} // namespace operators
} // namespace paddle
} // namespace funcs
} // namespace phi
/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/phi/kernels/sequence_pool_kernel.h"
#include "paddle/phi/kernels/impl/sequence_pool_kernel_impl.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
PD_REGISTER_KERNEL(
sequence_pool, GPU, ALL_LAYOUT, phi::SequencePoolKernel, float) {}
/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/phi/kernels/funcs/sequence_pooling.h"
namespace phi {
template <typename T, typename Context>
void SequencePoolKernel(const Context& ctx,
const DenseTensor& x,
bool is_test,
const std::string& pooltype,
float pad_value,
DenseTensor* out,
DenseTensor* max_index) {
T pad_value_ = static_cast<T>(pad_value);
auto dims = x.dims();
auto lod = x.lod();
auto lod_level = lod.size();
// InferShape by lod
PADDLE_ENFORCE_GT(
lod_level,
0,
errors::InvalidArgument("Input(X) phi::DenseTensor of SequencePoolOp "
"does not contain LoD information."));
PADDLE_ENFORCE_LE(
lod_level,
2UL,
errors::InvalidArgument("The lod level of input shall be no more than 2."
"Received lod level is %d.",
lod_level));
PADDLE_ENFORCE_GE(
dims[0],
/*batch size = */ static_cast<int64_t>(lod[lod_level - 1].size() - 1),
errors::InvalidArgument(
"The first dimension of Input(X) must be large than batch size."
"But received first dimension of Input(X) is %d, while batch"
"size is %d.",
dims[0],
static_cast<int64_t>(lod[lod_level - 1].size() - 1)));
if (lod_level > 1UL) {
PADDLE_ENFORCE_EQ(
lod[0][lod[0].size() - 1],
lod[1].size() - 1,
errors::InvalidArgument("The input lod information is illegal."));
phi::LoD out_lod;
out_lod.push_back(lod[0]);
out->set_lod(out_lod);
}
dims[0] = lod[lod_level - 1].size() - 1;
out->Resize({dims});
ctx.template Alloc<T>(out);
phi::DenseTensor* index = nullptr;
// Do not create index buffer for inference mode
if (pooltype == "MAX" &&
(is_test == false || (ctx.GetPlace() == phi::CPUPlace()) == false)) {
index = max_index;
index->Resize({dims});
ctx.template Alloc<int>(index);
}
phi::funcs::SequencePoolFunctor<Context, T> pool;
pool(ctx, pooltype, pad_value_, x, out, is_test, index);
}
} // namespace phi
/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/phi/core/dense_tensor.h"
namespace phi {
template <typename T, typename Context>
void SequencePoolKernel(const Context& ctx,
const DenseTensor& x,
bool is_test,
const std::string& pooltype,
float pad_value,
DenseTensor* out,
DenseTensor* max_index);
} // namespace phi
/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/phi/core/compat/op_utils.h"
namespace phi {
KernelSignature SequencePoolOpArgumentMapping(
const ArgumentMappingContext& ctx) {
return KernelSignature("sequence_pool",
{"X"},
{"is_test", "pooltype", "pad_value"},
{"Out", "MaxIndex"});
}
} // namespace phi
PD_REGISTER_ARG_MAPPING_FN(sequence_pool, phi::SequencePoolOpArgumentMapping);
......@@ -14,11 +14,10 @@ limitations under the License. */
#include <gtest/gtest.h>
#include "paddle/fluid/operators/math/sequence_pooling.h"
#include "paddle/phi/backends/context_pool.h"
#include "paddle/phi/common/place.h"
#include "paddle/phi/core/tensor_utils.h"
#include "paddle/phi/kernels/funcs/sequence_pooling.h"
template <typename DeviceContext, typename T>
void TestSequencePoolingSum(const DeviceContext &context,
......@@ -78,7 +77,7 @@ void TestSequencePoolingSum(const DeviceContext &context,
}
// call functor
paddle::operators::math::SequencePoolGradFunctor<DeviceContext, T>()(
phi::funcs::SequencePoolGradFunctor<DeviceContext, T>()(
context, "SUM", out_grad, &in_grad);
if (place == phi::CPUPlace()) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册