未验证 提交 fe053396 编写于 作者: G gouzil 提交者: GitHub

[phi] Move sequence_pool to phi - Step 3 :sequence_pool_grad_op (#52680)

* [phi] move sequence_pool kernel to phi

* mv kernels impl

* fix parameter error

* clean include

* fix compat filename

* [phi] move fluid sequence_pool_grad to phi

* [phi][compat] sig rm GradVarName

* [phi] fix sequence_pool out type

* [phi] rm impl, add const string

* [phi] fix const str

* fix sequence_pooling cmake

* [phi] mv sequence_pooling_test

* [phi] fix grad sig

* [phi] fix sequence_pool is_test error

* [phi] fix sequence_pooling gpu include

* [phi] mv to impl

* [phi] fix SequencePoolFunctor cu include

* [phi] modify out max_index int32_t

* [phi] add pooltype mapping determine

* [phi] fix sequence_pool_sig

* [phi] fix sequence_pool_sig sum

* [phi] try ci

* [phi] fix max_index optional
上级 182b6f83
...@@ -14,7 +14,6 @@ math_library(sample_prob) ...@@ -14,7 +14,6 @@ math_library(sample_prob)
math_library(sampler DEPS generator) math_library(sampler DEPS generator)
# math_library(math_function DEPS blas dense_tensor tensor) # math_library(math_function DEPS blas dense_tensor tensor)
if(WITH_XPU) if(WITH_XPU)
math_library(beam_search DEPS math_function beam_search_xpu) math_library(beam_search DEPS math_function beam_search_xpu)
else() else()
......
...@@ -12,10 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,10 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/operators/sequence_ops/sequence_pool_op.h" #include "paddle/fluid/framework/op_registry.h"
#include <memory>
#include <string>
namespace paddle { namespace paddle {
namespace operators { namespace operators {
...@@ -196,10 +193,3 @@ REGISTER_OPERATOR(sequence_pool, ...@@ -196,10 +193,3 @@ REGISTER_OPERATOR(sequence_pool,
REGISTER_OPERATOR(sequence_pool_grad, REGISTER_OPERATOR(sequence_pool_grad,
ops::SequencePoolGradOp, ops::SequencePoolGradOp,
ops::SequencePoolGradOpNoNeedBufferVarsInferer); ops::SequencePoolGradOpNoNeedBufferVarsInferer);
PD_REGISTER_STRUCT_KERNEL(sequence_pool_grad,
CPU,
ALL_LAYOUT,
ops::SequencePoolGradKernel,
float,
double) {}
...@@ -30,7 +30,6 @@ register_unity_group( ...@@ -30,7 +30,6 @@ register_unity_group(
sequence_expand_op.cu sequence_expand_op.cu
sequence_mask_op.cu sequence_mask_op.cu
sequence_pad_op.cu sequence_pad_op.cu
sequence_pool_op.cu
sequence_expand_as_op.cu sequence_expand_as_op.cu
sequence_reshape_op.cu sequence_reshape_op.cu
sequence_reverse_op.cu sequence_reverse_op.cu
......
/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/phi/kernels/sequence_pool_grad_kernel.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/sequence_pool_grad_kernel_impl.h"
// Register the phi `sequence_pool_grad` kernel for the CPU backend, for all
// layouts, instantiated for float and double element types. The empty brace
// body means no extra per-kernel argument configuration is needed.
PD_REGISTER_KERNEL(sequence_pool_grad,
                   CPU,
                   ALL_LAYOUT,
                   phi::SequencePoolGradKernel,
                   float,
                   double) {}
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. /* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License"); Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License. you may not use this file except in compliance with the License.
...@@ -11,8 +11,12 @@ distributed under the License is distributed on an "AS IS" BASIS, ...@@ -11,8 +11,12 @@ distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/operators/sequence_ops/sequence_pool_op.h"
namespace ops = paddle::operators; #include "paddle/phi/kernels/sequence_pool_grad_kernel.h"
PD_REGISTER_STRUCT_KERNEL( #include "paddle/phi/kernels/impl/sequence_pool_grad_kernel_impl.h"
sequence_pool_grad, GPU, ALL_LAYOUT, ops::SequencePoolGradKernel, float) {}
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
PD_REGISTER_KERNEL(
sequence_pool_grad, GPU, ALL_LAYOUT, phi::SequencePoolGradKernel, float) {}
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. /* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License"); Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License. you may not use this file except in compliance with the License.
...@@ -13,37 +13,27 @@ See the License for the specific language governing permissions and ...@@ -13,37 +13,27 @@ See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#pragma once #pragma once
#include <string>
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/phi/kernels/funcs/math_function.h"
#include "paddle/phi/kernels/funcs/sequence_pooling.h" #include "paddle/phi/kernels/funcs/sequence_pooling.h"
namespace paddle { namespace phi {
namespace operators {
template <typename T, typename Context>
template <typename T, typename DeviceContext> void SequencePoolGradKernel(const Context& dev_ctx,
class SequencePoolGradKernel : public framework::OpKernel<T> { const DenseTensor& x,
public: const paddle::optional<DenseTensor>& max_index,
void Compute(const framework::ExecutionContext& context) const override { const DenseTensor& out_grad,
auto* out_g = bool is_test,
context.Input<phi::DenseTensor>(framework::GradVarName("Out")); const std::string& pooltype,
auto* in_g = context.Output<phi::DenseTensor>(framework::GradVarName("X")); float pad_value,
std::string pooltype = context.Attr<std::string>("pooltype"); DenseTensor* x_grad) {
const phi::DenseTensor* index = nullptr; const phi::DenseTensor* index = nullptr;
if (pooltype == "MAX") { if (pooltype == "MAX") {
index = context.Input<phi::DenseTensor>("MaxIndex"); index = max_index.get_ptr();
} }
in_g->mutable_data<T>(context.GetPlace()); dev_ctx.template Alloc<T>(x_grad);
phi::funcs::SequencePoolGradFunctor<DeviceContext, T> pool; phi::funcs::SequencePoolGradFunctor<Context, T> pool;
pool(context.template device_context<DeviceContext>(), pool(dev_ctx, pooltype, out_grad, x_grad, index);
pooltype, }
*out_g,
in_g,
index);
}
};
} // namespace operators } // namespace phi
} // namespace paddle
...@@ -68,7 +68,7 @@ void SequencePoolKernel(const Context& ctx, ...@@ -68,7 +68,7 @@ void SequencePoolKernel(const Context& ctx,
(is_test == false || (ctx.GetPlace() == phi::CPUPlace()) == false)) { (is_test == false || (ctx.GetPlace() == phi::CPUPlace()) == false)) {
index = max_index; index = max_index;
index->Resize({dims}); index->Resize({dims});
ctx.template Alloc<int>(index); ctx.template Alloc<int32_t>(index);
} }
phi::funcs::SequencePoolFunctor<Context, T> pool; phi::funcs::SequencePoolFunctor<Context, T> pool;
pool(ctx, pooltype, pad_value_, x, out, is_test, index); pool(ctx, pooltype, pad_value_, x, out, is_test, index);
......
/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once

// IWYU fix: the declaration below uses std::string and paddle::optional
// directly, so include their headers here instead of relying on transitive
// includes through dense_tensor.h.
#include <string>

#include "paddle/phi/core/dense_tensor.h"
#include "paddle/utils/optional.h"

namespace phi {

// Gradient kernel of sequence_pool: computes the gradient of the pooled
// output w.r.t. the input sequence tensor.
//
// Parameters:
//   dev_ctx   - device context used for allocation and execution.
//   x         - forward input sequence tensor (provides LoD/shape info).
//   max_index - optional index tensor produced by the forward pass;
//               consumed only when pooltype == "MAX" (see kernel impl).
//   out_grad  - gradient of the forward output.
//   is_test   - forward-pass inference flag (mirrors the op attribute).
//   pooltype  - pooling kind, e.g. "MAX", "SUM", "AVERAGE".
//   pad_value - padding value used by the forward pooling.
//   x_grad    - output: gradient w.r.t. x, allocated by the kernel.
template <typename T, typename Context>
void SequencePoolGradKernel(const Context& dev_ctx,
                            const DenseTensor& x,
                            const paddle::optional<DenseTensor>& max_index,
                            const DenseTensor& out_grad,
                            bool is_test,
                            const std::string& pooltype,
                            float pad_value,
                            DenseTensor* x_grad);

}  // namespace phi
...@@ -21,6 +21,16 @@ KernelSignature SequencePoolOpArgumentMapping( ...@@ -21,6 +21,16 @@ KernelSignature SequencePoolOpArgumentMapping(
{"Out", "MaxIndex"}); {"Out", "MaxIndex"});
} }
// Maps the legacy fluid `sequence_pool_grad` op onto the phi kernel signature:
// inputs {X, MaxIndex, Out@GRAD}, attributes {is_test, pooltype, pad_value},
// output {X@GRAD}. MaxIndex is wired through unconditionally here; the kernel
// itself only dereferences it when pooltype == "MAX", which is why it is
// declared as an optional input on the phi side.
KernelSignature SequencePoolGradOpArgumentMapping(
    const ArgumentMappingContext& ctx) {
  return KernelSignature("sequence_pool_grad",
                         {"X", "MaxIndex", "Out@GRAD"},
                         {"is_test", "pooltype", "pad_value"},
                         {"X@GRAD"});
}
} // namespace phi } // namespace phi
PD_REGISTER_ARG_MAPPING_FN(sequence_pool, phi::SequencePoolOpArgumentMapping); PD_REGISTER_ARG_MAPPING_FN(sequence_pool, phi::SequencePoolOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(sequence_pool_grad,
phi::SequencePoolGradOpArgumentMapping);
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册