未验证 提交 a07f19ee 编写于 作者: Z Zhong Hui 提交者: GitHub

[PHI] Move segment_pool to phi. (#40099)

* move segment_pool to phi.

* mark summed ids as optional tensor.

* fix as reviews.
上级 548f2be4
...@@ -161,7 +161,7 @@ cc_library(common_infer_shape_functions SRCS common_infer_shape_functions.cc DEP ...@@ -161,7 +161,7 @@ cc_library(common_infer_shape_functions SRCS common_infer_shape_functions.cc DEP
set(COMMON_OP_DEPS ${COMMON_OP_DEPS} selected_rows_functor selected_rows_utils lapack_function set(COMMON_OP_DEPS ${COMMON_OP_DEPS} selected_rows_functor selected_rows_utils lapack_function
lod_tensor maxouting unpooling pooling lod_rank_table context_project lod_tensor maxouting unpooling pooling lod_rank_table context_project
sequence_pooling segment_pooling executor device_memory_aligment generator) sequence_pooling executor device_memory_aligment generator)
set(COMMON_OP_DEPS ${COMMON_OP_DEPS} dynload_warpctc) set(COMMON_OP_DEPS ${COMMON_OP_DEPS} dynload_warpctc)
set(COMMON_OP_DEPS ${COMMON_OP_DEPS} sequence_padding sequence_scale cos_sim_functor memory jit_kernel_helper concat_and_split cross_entropy softmax vol2col im2col sampler sample_prob tree2col) set(COMMON_OP_DEPS ${COMMON_OP_DEPS} sequence_padding sequence_scale cos_sim_functor memory jit_kernel_helper concat_and_split cross_entropy softmax vol2col im2col sampler sample_prob tree2col)
set(COMMON_OP_DEPS ${COMMON_OP_DEPS} sequence2batch lstm_compute matrix_bit_code gru_compute activation_functions beam_search fc matrix_inverse matrix_solve) set(COMMON_OP_DEPS ${COMMON_OP_DEPS} sequence2batch lstm_compute matrix_bit_code gru_compute activation_functions beam_search fc matrix_inverse matrix_solve)
......
...@@ -46,7 +46,6 @@ math_library(vol2col) ...@@ -46,7 +46,6 @@ math_library(vol2col)
math_library(prelu) math_library(prelu)
math_library(bert_encoder_functor) math_library(bert_encoder_functor)
math_library(tree2col DEPS math_function) math_library(tree2col DEPS math_function)
math_library(segment_pooling)
math_library(matrix_solve) math_library(matrix_solve)
cc_test(selected_rows_functor_test SRCS selected_rows_functor_test.cc DEPS selected_rows_functor) cc_test(selected_rows_functor_test SRCS selected_rows_functor_test.cc DEPS selected_rows_functor)
......
...@@ -12,9 +12,12 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,9 +12,12 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/operators/segment_pool_op.h"
#include <memory> #include <memory>
#include <string> #include <string>
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/infermeta/binary.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
...@@ -23,22 +26,6 @@ class SegmentPoolOp : public framework::OperatorWithKernel { ...@@ -23,22 +26,6 @@ class SegmentPoolOp : public framework::OperatorWithKernel {
public: public:
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "SegmentPool");
OP_INOUT_CHECK(ctx->HasInput("SegmentIds"), "Input", "SegmentIds",
"SegmentPool");
OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "SegmentPool");
auto dims = ctx->GetInputDim("X");
dims[0] = -1;
ctx->SetOutputDim("Out", dims);
if (ctx->Attrs().Get<std::string>("pooltype") == "MEAN") {
OP_INOUT_CHECK(ctx->HasOutput("SummedIds"), "Output", "SummedIds",
"SegmentPool");
ctx->SetOutputDim("SummedIds", {-1, 1});
}
}
protected: protected:
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
...@@ -150,17 +137,11 @@ class SegmentPoolGradOpMaker : public framework::SingleGradOpMaker<T> { ...@@ -150,17 +137,11 @@ class SegmentPoolGradOpMaker : public framework::SingleGradOpMaker<T> {
} // namespace paddle } // namespace paddle
namespace ops = paddle::operators; namespace ops = paddle::operators;
DECLARE_INFER_SHAPE_FUNCTOR(segment_pool, SegmentPoolInferShapeFunctor,
PD_INFER_META(phi::SegmentPoolInferMeta));
REGISTER_OPERATOR(segment_pool, ops::SegmentPoolOp, ops::SegmentPoolOpMaker, REGISTER_OPERATOR(segment_pool, ops::SegmentPoolOp, ops::SegmentPoolOpMaker,
ops::SegmentPoolGradOpMaker<paddle::framework::OpDesc>, ops::SegmentPoolGradOpMaker<paddle::framework::OpDesc>,
ops::SegmentPoolGradOpMaker<paddle::imperative::OpBase>); ops::SegmentPoolGradOpMaker<paddle::imperative::OpBase>,
SegmentPoolInferShapeFunctor);
REGISTER_OPERATOR(segment_pool_grad, ops::SegmentPoolGradOp); REGISTER_OPERATOR(segment_pool_grad, ops::SegmentPoolGradOp);
REGISTER_OP_CPU_KERNEL(
segment_pool,
ops::SegmentPoolKernel<paddle::platform::CPUDeviceContext, float>,
ops::SegmentPoolKernel<paddle::platform::CPUDeviceContext, double>);
REGISTER_OP_CPU_KERNEL(
segment_pool_grad,
ops::SegmentPoolGradKernel<paddle::platform::CPUDeviceContext, float>,
ops::SegmentPoolGradKernel<paddle::platform::CPUDeviceContext, double>);
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/segment_pool_op.h"
#include "paddle/fluid/platform/device/gpu/gpu_launch_config.h"
#include "paddle/fluid/platform/device/gpu/gpu_primitives.h"
namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(
segment_pool,
ops::SegmentPoolKernel<paddle::platform::CUDADeviceContext, float>,
ops::SegmentPoolKernel<paddle::platform::CUDADeviceContext, double>);
REGISTER_OP_CUDA_KERNEL(
segment_pool_grad,
ops::SegmentPoolGradKernel<paddle::platform::CUDADeviceContext, float>,
ops::SegmentPoolGradKernel<paddle::platform::CUDADeviceContext, double>);
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <string>
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/math/segment_pooling.h"
#include "paddle/fluid/platform/macros.h"
#include "paddle/phi/common/place.h"
#include "paddle/phi/kernels/funcs/math_function.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
template <typename DeviceContext, typename T, typename IndexT>
void SegmentKernelLaunchHelper(const framework::ExecutionContext& context) {
auto* input = context.Input<Tensor>("X");
auto* segment = context.Input<Tensor>("SegmentIds");
auto* output = context.Output<Tensor>("Out");
std::string pooltype = context.Attr<std::string>("pooltype");
Tensor* summed_ids = nullptr;
int64_t num_indices = segment->numel();
PADDLE_ENFORCE_EQ(
num_indices, input->dims()[0],
platform::errors::InvalidArgument(
"Segment_ids should be the same size as dimension 0 of input X."));
PADDLE_ENFORCE_EQ(num_indices, segment->dims()[0],
platform::errors::InvalidArgument(
"Segment_ids should be 1-D tensor, or it's other "
"dimension size is 1. Segment_ids's shape is: [%s].",
segment->dims()));
if (input->numel() == 0 || segment->numel() == 0) {
return;
}
bool cpu_place = context.GetPlace().GetType() == phi::AllocationType::CPU;
if (cpu_place) {
auto dims = input->dims();
auto* segment_ids = segment->data<IndexT>();
dims[0] = static_cast<int64_t>(segment_ids[segment->numel() - 1] + 1);
PADDLE_ENFORCE_GT(
dims[0], 0,
platform::errors::InvalidArgument(
"Segment ids must be >= 0, but got last id %d", dims[0]));
output->Resize({dims});
output->mutable_data<T>(context.GetPlace());
phi::funcs::SetConstant<DeviceContext, T> set_zero;
auto& dev_ctx = context.template device_context<DeviceContext>();
set_zero(dev_ctx, output, static_cast<T>(0));
}
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
if (!cpu_place) {
Tensor length;
length.mutable_data<IndexT>(phi::make_ddim({1}), platform::CPUPlace());
IndexT* length_data = length.data<IndexT>();
const IndexT* segment_ids = segment->data<IndexT>();
#ifdef PADDLE_WITH_HIP
PADDLE_ENFORCE_GPU_SUCCESS(
hipMemcpy(length_data, segment_ids + num_indices - 1, sizeof(IndexT),
hipMemcpyDeviceToHost));
#else
PADDLE_ENFORCE_GPU_SUCCESS(
cudaMemcpy(length_data, segment_ids + num_indices - 1, sizeof(IndexT),
cudaMemcpyDeviceToHost));
#endif
IndexT length_host = length_data[0];
length_host++;
PADDLE_ENFORCE_GT(
length_host, 0,
platform::errors::InvalidArgument(
"Segment ids must be >= 0, but got last id %d", length_data[0]));
auto dims = input->dims();
dims[0] = static_cast<int64_t>(length_host);
output->Resize({dims});
output->mutable_data<T>(context.GetPlace());
T init_value = 0;
if (pooltype == "MAX") {
init_value = static_cast<T>(-FLT_MAX);
} else if (pooltype == "MIN") {
init_value = static_cast<T>(FLT_MAX);
}
phi::funcs::SetConstant<DeviceContext, T> setconst;
auto& dev_ctx = context.template device_context<DeviceContext>();
setconst(dev_ctx, output, static_cast<T>(init_value));
// the gpu kernel of mean pool record the counts of segment_ids
if (pooltype == "MEAN") {
summed_ids = context.Output<Tensor>("SummedIds");
summed_ids->Resize({dims[0], 1});
summed_ids->mutable_data<T>(context.GetPlace());
setconst(dev_ctx, summed_ids, static_cast<T>(1e-12));
}
}
#endif
SegmentPoolFunctor<DeviceContext, T, IndexT> pool;
pool(context.template device_context<DeviceContext>(), *input, *segment,
output, summed_ids, pooltype);
}
template <typename DeviceContext, typename T>
class SegmentPoolKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* segment = context.Input<Tensor>("SegmentIds");
auto index_type = framework::TransToProtoVarType(segment->dtype());
if (index_type == framework::proto::VarType::INT32) {
SegmentKernelLaunchHelper<DeviceContext, T, int>(context);
} else if (index_type == framework::proto::VarType::INT64) {
SegmentKernelLaunchHelper<DeviceContext, T, int64_t>(context);
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"Unsupported index type, Expected int, int64, but got %s.",
index_type));
}
}
};
template <typename DeviceContext, typename T>
class SegmentPoolGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* input = context.Input<Tensor>("X");
auto* output = context.Input<Tensor>("Out");
auto* segment = context.Input<Tensor>("SegmentIds");
auto* out_g = context.Input<Tensor>(framework::GradVarName("Out"));
auto* in_g = context.Output<Tensor>(framework::GradVarName("X"));
std::string pooltype = context.Attr<std::string>("pooltype");
const Tensor* summed_ids = nullptr;
if (pooltype == "MEAN") {
summed_ids = context.Input<Tensor>("SummedIds");
}
in_g->mutable_data<T>(context.GetPlace());
phi::funcs::SetConstant<DeviceContext, T> set_zero;
auto& dev_ctx = context.template device_context<DeviceContext>();
set_zero(dev_ctx, in_g, static_cast<T>(0));
auto index_type = framework::TransToProtoVarType(segment->dtype());
if (index_type == framework::proto::VarType::INT32) {
SegmentPoolGradFunctor<DeviceContext, T, int> pool;
pool(context.template device_context<DeviceContext>(), *input, *output,
*out_g, *segment, in_g, summed_ids, pooltype);
} else if (index_type == framework::proto::VarType::INT64) {
SegmentPoolGradFunctor<DeviceContext, T, int64_t> pool;
pool(context.template device_context<DeviceContext>(), *input, *output,
*out_g, *segment, in_g, summed_ids, pooltype);
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"Unsupported index type, Expected int, int64, but got %s.",
index_type));
}
}
};
} // namespace operators
} // namespace paddle
...@@ -236,7 +236,6 @@ register_unity_group(cc ...@@ -236,7 +236,6 @@ register_unity_group(cc
scatter_nd_add_op.cc scatter_nd_add_op.cc
scatter_op.cc scatter_op.cc
seed_op.cc seed_op.cc
segment_pool_op.cc
select_input_op.cc select_input_op.cc
select_output_op.cc) select_output_op.cc)
register_unity_group(cc register_unity_group(cc
...@@ -496,8 +495,7 @@ register_unity_group(cu ...@@ -496,8 +495,7 @@ register_unity_group(cu
scale_op.cu scale_op.cu
scatter_nd_add_op.cu scatter_nd_add_op.cu
scatter_op.cu scatter_op.cu
seed_op.cu seed_op.cu)
segment_pool_op.cu)
register_unity_group(cu register_unity_group(cu
roi_pool_op.cu roi_pool_op.cu
selu_op.cu selu_op.cu
......
...@@ -417,6 +417,25 @@ void Atan2InferMeta(const MetaTensor& x, const MetaTensor& y, MetaTensor* out) { ...@@ -417,6 +417,25 @@ void Atan2InferMeta(const MetaTensor& x, const MetaTensor& y, MetaTensor* out) {
out->share_meta(x); out->share_meta(x);
} }
void SegmentPoolInferMeta(const MetaTensor& x,
const MetaTensor& segment_ids,
const std::string& pooltype,
MetaTensor* out,
MetaTensor* summed_ids,
MetaConfig config) {
auto dims = x.dims();
dims[0] = -1;
out->set_dims(dims);
out->set_dtype(x.dtype());
out->set_layout(x.layout());
if (pooltype == "MEAN") {
summed_ids->set_dims({-1, 1});
summed_ids->set_dtype(x.dtype());
summed_ids->set_layout(x.layout());
}
}
void BCELossInferMeta(const MetaTensor& input, void BCELossInferMeta(const MetaTensor& input,
const MetaTensor& label, const MetaTensor& label,
MetaTensor* out, MetaTensor* out,
......
...@@ -80,6 +80,14 @@ void CrossInferMeta(const MetaTensor& x, ...@@ -80,6 +80,14 @@ void CrossInferMeta(const MetaTensor& x,
MetaTensor* out); MetaTensor* out);
void Atan2InferMeta(const MetaTensor& x, const MetaTensor& y, MetaTensor* out); void Atan2InferMeta(const MetaTensor& x, const MetaTensor& y, MetaTensor* out);
void SegmentPoolInferMeta(const MetaTensor& x,
const MetaTensor& segment_ids,
const std::string& pooltype,
MetaTensor* out,
MetaTensor* summed_ids,
MetaConfig config = MetaConfig());
void BCELossInferMeta(const MetaTensor& input, void BCELossInferMeta(const MetaTensor& input,
const MetaTensor& label, const MetaTensor& label,
MetaTensor* out, MetaTensor* out,
......
...@@ -27,7 +27,7 @@ kernel_library(full_kernel DEPS ${COMMON_KERNEL_DEPS} empty_kernel) ...@@ -27,7 +27,7 @@ kernel_library(full_kernel DEPS ${COMMON_KERNEL_DEPS} empty_kernel)
# Some kernels depend on some targets that are not commonly used. # Some kernels depend on some targets that are not commonly used.
# These targets are not suitable for common dependencies. # These targets are not suitable for common dependencies.
# In this case, you need to manually generate them here. # In this case, you need to manually generate them here.
set(MANUAL_BUILD_KERNELS math_kernel softmax_kernel softmax_grad_kernel triangular_solve_grad_kernel maxout_kernel maxout_grad_kernel put_along_axis_kernel put_along_axis_grad_kernel take_along_axis_kernel take_along_axis_grad_kernel eigh_kernel) set(MANUAL_BUILD_KERNELS math_kernel softmax_kernel softmax_grad_kernel triangular_solve_grad_kernel maxout_kernel maxout_grad_kernel put_along_axis_kernel put_along_axis_grad_kernel take_along_axis_kernel take_along_axis_grad_kernel eigh_kernel segment_pool_kernel segment_pool_grad_kernel)
kernel_library(math_kernel DEPS ${COMMON_KERNEL_DEPS} cast_kernel copy_kernel) kernel_library(math_kernel DEPS ${COMMON_KERNEL_DEPS} cast_kernel copy_kernel)
kernel_library(softmax_kernel DEPS ${COMMON_KERNEL_DEPS} softmax) kernel_library(softmax_kernel DEPS ${COMMON_KERNEL_DEPS} softmax)
kernel_library(softmax_grad_kernel DEPS ${COMMON_KERNEL_DEPS} softmax) kernel_library(softmax_grad_kernel DEPS ${COMMON_KERNEL_DEPS} softmax)
...@@ -39,6 +39,8 @@ kernel_library(put_along_axis_grad_kernel DEPS ${COMMON_KERNEL_DEPS} gather_scat ...@@ -39,6 +39,8 @@ kernel_library(put_along_axis_grad_kernel DEPS ${COMMON_KERNEL_DEPS} gather_scat
kernel_library(take_along_axis_kernel DEPS ${COMMON_KERNEL_DEPS} gather_scatter_kernel) kernel_library(take_along_axis_kernel DEPS ${COMMON_KERNEL_DEPS} gather_scatter_kernel)
kernel_library(take_along_axis_grad_kernel DEPS ${COMMON_KERNEL_DEPS} gather_scatter_kernel) kernel_library(take_along_axis_grad_kernel DEPS ${COMMON_KERNEL_DEPS} gather_scatter_kernel)
kernel_library(eigh_kernel DEPS ${COMMON_KERNEL_DEPS} lapack_function) kernel_library(eigh_kernel DEPS ${COMMON_KERNEL_DEPS} lapack_function)
kernel_library(segment_pool_kernel DEPS ${COMMON_KERNEL_DEPS} segment_pooling)
kernel_library(segment_pool_grad_kernel DEPS ${COMMON_KERNEL_DEPS} segment_pooling)
# 4. auto parse and build kernel targets by cmake # 4. auto parse and build kernel targets by cmake
register_kernels(EXCLUDES ${COMMON_BAISC_KERNELS} ${MANUAL_BUILD_KERNELS} DEPS ${COMMON_KERNEL_DEPS} ${COMMON_BAISC_KERNELS} ) register_kernels(EXCLUDES ${COMMON_BAISC_KERNELS} ${MANUAL_BUILD_KERNELS} DEPS ${COMMON_KERNEL_DEPS} ${COMMON_BAISC_KERNELS} )
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/segment_pool_grad_kernel.h"
#include "paddle/phi/kernels/impl/segment_pool_grad_kernel_impl.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
PD_REGISTER_KERNEL(segment_pool_grad,
CPU,
ALL_LAYOUT,
phi::SegmentPoolGradKernel,
float,
double) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/segment_pool_kernel.h"
#include "paddle/phi/kernels/impl/segment_pool_kernel_impl.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
PD_REGISTER_KERNEL(
segment_pool, CPU, ALL_LAYOUT, phi::SegmentPoolKernel, float, double) {}
...@@ -4,6 +4,7 @@ add_subdirectory(lapack) ...@@ -4,6 +4,7 @@ add_subdirectory(lapack)
add_subdirectory(detail) add_subdirectory(detail)
math_library(math_function DEPS blas dense_tensor tensor) math_library(math_function DEPS blas dense_tensor tensor)
math_library(segment_pooling)
math_library(sequence2batch) math_library(sequence2batch)
math_library(gru_compute DEPS activation_functions math_function) math_library(gru_compute DEPS activation_functions math_function)
math_library(lstm_compute DEPS activation_functions) math_library(lstm_compute DEPS activation_functions)
......
...@@ -12,45 +12,52 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,45 +12,52 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/operators/math/segment_pooling.h" #include "paddle/phi/kernels/funcs/segment_pooling.h"
#include <string> #include <string>
#include "paddle/fluid/framework/eigen.h"
namespace paddle { #include "paddle/phi/backends/cpu/cpu_context.h"
namespace operators { #include "paddle/phi/kernels/funcs/eigen/common.h"
using Tensor = framework::Tensor; namespace phi {
namespace funcs {
using Tensor = DenseTensor;
template <typename T, typename IndexT> template <typename T, typename IndexT>
class SegmentPoolFunctor<platform::CPUDeviceContext, T, IndexT> { class SegmentPoolFunctor<phi::CPUContext, T, IndexT> {
public: public:
void operator()(const platform::CPUDeviceContext& context, void operator()(const phi::CPUContext& dev_ctx,
const framework::Tensor& input, const DenseTensor& input,
const framework::Tensor& segments, framework::Tensor* output, const DenseTensor& segments,
framework::Tensor* index, DenseTensor* output,
DenseTensor* index,
const std::string pooltype = "SUM") { const std::string pooltype = "SUM") {
const IndexT* segment_ids = segments.data<IndexT>(); const IndexT* segment_ids = segments.data<IndexT>();
auto curent_id = segment_ids[0]; auto curent_id = segment_ids[0];
int64_t last_idx = 0; int64_t last_idx = 0;
int64_t w = input.numel() / input.dims()[0]; int64_t w = input.numel() / input.dims()[0];
auto& place = *context.eigen_device(); auto& place = *dev_ctx.eigen_device();
for (int64_t idx = 1; idx <= segments.numel(); ++idx) { for (int64_t idx = 1; idx <= segments.numel(); ++idx) {
if (idx < segments.numel()) { if (idx < segments.numel()) {
if (segment_ids[idx] == curent_id) continue; if (segment_ids[idx] == curent_id) continue;
PADDLE_ENFORCE_GE(segment_ids[idx], curent_id, PADDLE_ENFORCE_GE(segment_ids[idx],
platform::errors::InvalidArgument( curent_id,
phi::errors::InvalidArgument(
"The segment ids should be sorted, but got " "The segment ids should be sorted, but got "
"segment_ids[%d]:%d > segment_ids[%d]:%d.", "segment_ids[%d]:%d > segment_ids[%d]:%d.",
idx - 1, curent_id, idx, segment_ids[idx])); idx - 1,
curent_id,
idx,
segment_ids[idx]));
} }
Tensor out_t = output->Slice(curent_id, curent_id + 1); Tensor out_t = output->Slice(curent_id, curent_id + 1);
Tensor in_t = input.Slice(last_idx, idx); Tensor in_t = input.Slice(last_idx, idx);
int64_t h = idx - last_idx; int64_t h = idx - last_idx;
auto in_e = framework::EigenMatrix<T>::From(in_t, phi::make_ddim({h, w})); auto in_e = EigenMatrix<T>::From(in_t, phi::make_ddim({h, w}));
auto out_e = framework::EigenVector<T>::Flatten(out_t); auto out_e = EigenVector<T>::Flatten(out_t);
auto reduce_dim = Eigen::array<int, 1>({{0}}); auto reduce_dim = Eigen::array<int, 1>({{0}});
if (pooltype == "MEAN") { if (pooltype == "MEAN") {
...@@ -62,7 +69,7 @@ class SegmentPoolFunctor<platform::CPUDeviceContext, T, IndexT> { ...@@ -62,7 +69,7 @@ class SegmentPoolFunctor<platform::CPUDeviceContext, T, IndexT> {
} else if (pooltype == "MIN") { } else if (pooltype == "MIN") {
out_e.device(place) = in_e.minimum(reduce_dim); out_e.device(place) = in_e.minimum(reduce_dim);
} else { } else {
PADDLE_THROW(platform::errors::InvalidArgument( PADDLE_THROW(phi::errors::InvalidArgument(
"Unsupported segment pooling type, only MEAN, SUM, MAX, MIN " "Unsupported segment pooling type, only MEAN, SUM, MAX, MIN "
"available, but got %s.", "available, but got %s.",
pooltype)); pooltype));
...@@ -75,36 +82,41 @@ class SegmentPoolFunctor<platform::CPUDeviceContext, T, IndexT> { ...@@ -75,36 +82,41 @@ class SegmentPoolFunctor<platform::CPUDeviceContext, T, IndexT> {
}; };
template <typename T, typename IndexT> template <typename T, typename IndexT>
class SegmentPoolGradFunctor<platform::CPUDeviceContext, T, IndexT> { class SegmentPoolGradFunctor<phi::CPUContext, T, IndexT> {
public: public:
void operator()(const platform::CPUDeviceContext& context, void operator()(const phi::CPUContext& dev_ctx,
const framework::Tensor& input, const DenseTensor& input,
const framework::Tensor& output, const DenseTensor& output,
const framework::Tensor& out_grad, const DenseTensor& out_grad,
const framework::Tensor& segments, framework::Tensor* in_grad, const DenseTensor& segments,
const framework::Tensor* index = nullptr, DenseTensor* in_grad,
paddle::optional<const DenseTensor&> index,
const std::string pooltype = "SUM") { const std::string pooltype = "SUM") {
const IndexT* segment_ids = segments.data<IndexT>(); const IndexT* segment_ids = segments.data<IndexT>();
auto& place = *context.eigen_device(); auto& place = *dev_ctx.eigen_device();
auto curent_id = segment_ids[0]; auto curent_id = segment_ids[0];
int64_t last_idx = 0; int64_t last_idx = 0;
int64_t w = in_grad->numel() / in_grad->dims()[0]; int64_t w = in_grad->numel() / in_grad->dims()[0];
for (int64_t idx = 1; idx <= segments.numel(); ++idx) { for (int64_t idx = 1; idx <= segments.numel(); ++idx) {
if (idx < segments.numel()) { if (idx < segments.numel()) {
if (segment_ids[idx] == curent_id) continue; if (segment_ids[idx] == curent_id) continue;
PADDLE_ENFORCE_GE(segment_ids[idx], curent_id, PADDLE_ENFORCE_GE(segment_ids[idx],
platform::errors::InvalidArgument( curent_id,
phi::errors::InvalidArgument(
"The segment ids should be sorted, but got " "The segment ids should be sorted, but got "
"segment_ids[%d]:%d > segment_ids[%d]:%d.", "segment_ids[%d]:%d > segment_ids[%d]:%d.",
idx - 1, curent_id, idx, segment_ids[idx])); idx - 1,
curent_id,
idx,
segment_ids[idx]));
} }
Tensor out_g_t = out_grad.Slice(curent_id, curent_id + 1); Tensor out_g_t = out_grad.Slice(curent_id, curent_id + 1);
Tensor in_g_t = in_grad->Slice(last_idx, idx); Tensor in_g_t = in_grad->Slice(last_idx, idx);
int64_t h = idx - last_idx; int64_t h = idx - last_idx;
auto in_g_e = framework::EigenMatrix<T>::From(in_g_t, {h, w}); auto in_g_e = EigenMatrix<T>::From(in_g_t, {h, w});
auto out_g_e = framework::EigenMatrix<T>::From(out_g_t, {1, w}); auto out_g_e = EigenMatrix<T>::From(out_g_t, {1, w});
Eigen::DSizes<int, 2> bcast(h, 1); Eigen::DSizes<int, 2> bcast(h, 1);
if (pooltype == "MEAN") { if (pooltype == "MEAN") {
...@@ -114,13 +126,13 @@ class SegmentPoolGradFunctor<platform::CPUDeviceContext, T, IndexT> { ...@@ -114,13 +126,13 @@ class SegmentPoolGradFunctor<platform::CPUDeviceContext, T, IndexT> {
} else if (pooltype == "MAX" || pooltype == "MIN") { } else if (pooltype == "MAX" || pooltype == "MIN") {
Tensor out_t = output.Slice(curent_id, curent_id + 1); Tensor out_t = output.Slice(curent_id, curent_id + 1);
Tensor in_t = input.Slice(last_idx, idx); Tensor in_t = input.Slice(last_idx, idx);
auto in_e = framework::EigenMatrix<T>::From(in_t, {h, w}); auto in_e = EigenMatrix<T>::From(in_t, {h, w});
auto out_e = framework::EigenMatrix<T>::From(out_t, {1, w}); auto out_e = EigenMatrix<T>::From(out_t, {1, w});
in_g_e.device(place) = in_g_e.device(place) =
(in_e == out_e.broadcast(bcast)).template cast<T>() * (in_e == out_e.broadcast(bcast)).template cast<T>() *
out_g_e.broadcast(bcast); out_g_e.broadcast(bcast);
} else { } else {
PADDLE_THROW(platform::errors::InvalidArgument( PADDLE_THROW(phi::errors::InvalidArgument(
"Unsupported segment pooling type, only MEAN, SUM, MAX, MIN " "Unsupported segment pooling type, only MEAN, SUM, MAX, MIN "
"available, but got %s.", "available, but got %s.",
pooltype)); pooltype));
...@@ -132,7 +144,7 @@ class SegmentPoolGradFunctor<platform::CPUDeviceContext, T, IndexT> { ...@@ -132,7 +144,7 @@ class SegmentPoolGradFunctor<platform::CPUDeviceContext, T, IndexT> {
} }
}; };
using CPU = platform::CPUDeviceContext; using CPU = phi::CPUContext;
template class SegmentPoolFunctor<CPU, float, int>; template class SegmentPoolFunctor<CPU, float, int>;
template class SegmentPoolFunctor<CPU, float, int64_t>; template class SegmentPoolFunctor<CPU, float, int64_t>;
template class SegmentPoolFunctor<CPU, double, int>; template class SegmentPoolFunctor<CPU, double, int>;
...@@ -142,5 +154,5 @@ template class SegmentPoolGradFunctor<CPU, float, int64_t>; ...@@ -142,5 +154,5 @@ template class SegmentPoolGradFunctor<CPU, float, int64_t>;
template class SegmentPoolGradFunctor<CPU, double, int>; template class SegmentPoolGradFunctor<CPU, double, int>;
template class SegmentPoolGradFunctor<CPU, double, int64_t>; template class SegmentPoolGradFunctor<CPU, double, int64_t>;
} // namespace operators } // namespace funcs
} // namespace paddle } // namespace phi
...@@ -14,33 +14,36 @@ limitations under the License. */ ...@@ -14,33 +14,36 @@ limitations under the License. */
#pragma once #pragma once
#include <string> #include <string>
#include "paddle/fluid/framework/tensor.h" #include "paddle/phi/core/dense_tensor.h"
#include "paddle/fluid/platform/device_context.h"
namespace paddle { namespace phi {
namespace operators { namespace funcs {
template <typename DeviceContext, typename T, typename IndexT> template <typename Context, typename T, typename IndexT>
class SegmentPoolFunctor { class SegmentPoolFunctor {
public: public:
/* mean pool has summed_ids output */ /* mean pool has summed_ids output */
void operator()(const DeviceContext& context, const framework::Tensor& input, void operator()(const Context& dev_ctx,
const framework::Tensor& segments, framework::Tensor* output, const DenseTensor& input,
framework::Tensor* summed_ids = nullptr, const DenseTensor& segments,
DenseTensor* output,
DenseTensor* summed_ids = nullptr,
const std::string pooltype = "SUM"); const std::string pooltype = "SUM");
}; };
template <typename DeviceContext, typename T, typename IndexT> template <typename Context, typename T, typename IndexT>
class SegmentPoolGradFunctor { class SegmentPoolGradFunctor {
public: public:
/* mean pool has summed_ids output */ /* mean pool has summed_ids output */
void operator()(const DeviceContext& context, const framework::Tensor& input, void operator()(const Context& dev_ctx,
const framework::Tensor& output, const DenseTensor& input,
const framework::Tensor& out_grad, const DenseTensor& output,
const framework::Tensor& segments, framework::Tensor* in_grad, const DenseTensor& out_grad,
const framework::Tensor* summed_ids = nullptr, const DenseTensor& segments,
DenseTensor* in_grad,
paddle::optional<const DenseTensor&> summed_ids,
const std::string pooltype = "SUM"); const std::string pooltype = "SUM");
}; };
} // namespace operators } // namespace funcs
} // namespace paddle } // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/impl/segment_pool_grad_kernel_impl.h"
#include "paddle/phi/kernels/segment_pool_grad_kernel.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/kernel_registry.h"
PD_REGISTER_KERNEL(segment_pool_grad,
GPU,
ALL_LAYOUT,
phi::SegmentPoolGradKernel,
float,
double) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/impl/segment_pool_kernel_impl.h"
#include "paddle/phi/kernels/segment_pool_kernel.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/kernel_registry.h"
PD_REGISTER_KERNEL(
segment_pool, GPU, ALL_LAYOUT, phi::SegmentPoolKernel, float, double) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include "paddle/fluid/framework/eigen.h"
#include "paddle/phi/common/place.h"
#include "paddle/phi/kernels/funcs/math_function.h"
#include "paddle/phi/kernels/funcs/segment_pooling.h"
namespace phi {
template <typename T, typename Context>
void SegmentPoolGradKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& segment_ids,
const DenseTensor& out,
paddle::optional<const DenseTensor&> summed_ids,
const DenseTensor& out_grad,
const std::string& pooltype,
DenseTensor* x_grad) {
dev_ctx.template Alloc<T>(x_grad);
phi::funcs::SetConstant<Context, T> set_zero;
set_zero(dev_ctx, x_grad, static_cast<T>(0));
auto index_type = segment_ids.type();
if (index_type == DataType::INT32) {
phi::funcs::SegmentPoolGradFunctor<Context, T, int> pool;
pool(dev_ctx, x, out, out_grad, segment_ids, x_grad, summed_ids, pooltype);
} else if (index_type == DataType::INT64) {
phi::funcs::SegmentPoolGradFunctor<Context, T, int64_t> pool;
pool(dev_ctx, x, out, out_grad, segment_ids, x_grad, summed_ids, pooltype);
} else {
PADDLE_THROW(phi::errors::InvalidArgument(
"Unsupported index type, Expected int, int64, but got %s.",
index_type));
}
}
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include "paddle/fluid/framework/eigen.h"
#include "paddle/phi/common/place.h"
#include "paddle/phi/kernels/funcs/math_function.h"
#include "paddle/phi/kernels/funcs/segment_pooling.h"
namespace phi {
template <typename Context, typename T, typename IndexT>
void SegmentKernelLaunchHelper(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& segment_ids,
const std::string& pooltype,
DenseTensor* out,
DenseTensor* summed_ids) {
int64_t num_indices = segment_ids.numel();
PADDLE_ENFORCE_EQ(
num_indices,
x.dims()[0],
phi::errors::InvalidArgument(
"Segment_ids should be the same size as dimension 0 of input X."));
PADDLE_ENFORCE_EQ(num_indices,
segment_ids.dims()[0],
phi::errors::InvalidArgument(
"Segment_ids should be 1-D tensor, or it's other "
"dimension size is 1. Segment_ids's shape is: [%s].",
segment_ids.dims()));
if (x.numel() == 0 || segment_ids.numel() == 0) {
return;
}
bool cpu_place = dev_ctx.GetPlace().GetType() == phi::AllocationType::CPU;
if (cpu_place) {
auto dims = x.dims();
auto* segment_ids_ptr = segment_ids.data<IndexT>();
dims[0] =
static_cast<int64_t>(segment_ids_ptr[segment_ids.numel() - 1] + 1);
PADDLE_ENFORCE_GT(
dims[0],
0,
phi::errors::InvalidArgument(
"Segment ids must be >= 0, but got last id %d", dims[0]));
out->Resize({dims});
dev_ctx.template Alloc<T>(out);
phi::funcs::SetConstant<Context, T> set_zero;
set_zero(dev_ctx, out, static_cast<T>(0));
}
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
if (!cpu_place) {
DenseTensor length;
length.Resize(phi::make_ddim({1}));
IndexT* length_data = dev_ctx.template HostAlloc<IndexT>(&length);
const IndexT* segment_ids_ptr = segment_ids.data<IndexT>();
#ifdef PADDLE_WITH_HIP
PADDLE_ENFORCE_GPU_SUCCESS(hipMemcpy(length_data,
segment_ids_ptr + num_indices - 1,
sizeof(IndexT),
hipMemcpyDeviceToHost));
#else
PADDLE_ENFORCE_GPU_SUCCESS(cudaMemcpy(length_data,
segment_ids_ptr + num_indices - 1,
sizeof(IndexT),
cudaMemcpyDeviceToHost));
#endif
IndexT length_host = length_data[0];
length_host++;
PADDLE_ENFORCE_GT(
length_host,
0,
phi::errors::InvalidArgument(
"Segment ids must be >= 0, but got last id %d", length_data[0]));
auto dims = x.dims();
dims[0] = static_cast<int64_t>(length_host);
out->Resize({dims});
dev_ctx.template Alloc<T>(out);
T init_value = 0;
if (pooltype == "MAX") {
init_value = static_cast<T>(-FLT_MAX);
} else if (pooltype == "MIN") {
init_value = static_cast<T>(FLT_MAX);
}
phi::funcs::SetConstant<Context, T> setconst;
setconst(dev_ctx, out, static_cast<T>(init_value));
// the gpu kernel of mean pool record the counts of segment_ids
if (pooltype == "MEAN") {
summed_ids->Resize({dims[0], 1});
dev_ctx.template Alloc<T>(summed_ids);
setconst(dev_ctx, summed_ids, static_cast<T>(1e-12));
}
}
#endif
phi::funcs::SegmentPoolFunctor<Context, T, IndexT> pool;
pool(dev_ctx, x, segment_ids, out, summed_ids, pooltype);
}
template <typename T, typename Context>
void SegmentPoolKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& segment_ids,
const std::string& pooltype,
DenseTensor* out,
DenseTensor* summed_ids) {
auto index_type = segment_ids.dtype();
if (index_type == DataType::INT32) {
SegmentKernelLaunchHelper<Context, T, int>(
dev_ctx, x, segment_ids, pooltype, out, summed_ids);
} else if (index_type == DataType::INT64) {
SegmentKernelLaunchHelper<Context, T, int64_t>(
dev_ctx, x, segment_ids, pooltype, out, summed_ids);
} else {
PADDLE_THROW(phi::errors::InvalidArgument(
"Unsupported index type, Expected int, int64, but got %s.",
index_type));
}
}
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/core/dense_tensor.h"
namespace phi {
template <typename T, typename Context>
void SegmentPoolGradKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& segment_ids,
const DenseTensor& out,
paddle::optional<const DenseTensor&> summed_ids,
const DenseTensor& out_grad,
const std::string& pooltype,
DenseTensor* x_grad);
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/core/dense_tensor.h"
namespace phi {
template <typename T, typename Context>
void SegmentPoolKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& segment_ids,
const std::string& pooltype,
DenseTensor* out,
DenseTensor* summed_ids);
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/core/compat/op_utils.h"
namespace phi {
KernelSignature SegmentPoolGradOpArgumentMapping(
const ArgumentMappingContext& ctx) {
return KernelSignature(
"segment_pool_grad",
{
"X", "SegmentIds", "Out", "SummedIds", GradVarName("Out"),
},
{"pooltype"},
{GradVarName("X")});
}
} // namespace phi
PD_REGISTER_ARG_MAPPING_FN(segment_pool_grad,
phi::SegmentPoolGradOpArgumentMapping);
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册