diff --git a/paddle/fluid/operators/reduce_ops/logsumexp_op.cc b/paddle/fluid/operators/reduce_ops/logsumexp_op.cc index 9f0ef19bd6299cae2f1538377b58a3e97eafe7c0..0602c73db6bbc92727c5a5d7673a4d0c20265458 100644 --- a/paddle/fluid/operators/reduce_ops/logsumexp_op.cc +++ b/paddle/fluid/operators/reduce_ops/logsumexp_op.cc @@ -12,10 +12,13 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "paddle/fluid/operators/reduce_ops/logsumexp_op.h" #include #include #include +#include "paddle/fluid/framework/infershape_utils.h" +#include "paddle/fluid/operators/reduce_ops/reduce_op_function.h" +#include "paddle/phi/core/infermeta_utils.h" +#include "paddle/phi/infermeta/unary.h" namespace paddle { namespace operators { @@ -23,80 +26,6 @@ namespace operators { class LogsumexpOp : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; - - void InferShape(framework::InferShapeContext* ctx) const override { - OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "logsumexp"); - OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "logsumexp"); - auto x_dims = ctx->GetInputDim("X"); - auto x_rank = x_dims.size(); - PADDLE_ENFORCE_LE(x_rank, 4, - platform::errors::InvalidArgument( - "The input tensor X's dimensions of logsumexp " - "should be less or equal than 4. But received X's " - "dimensions = %d, X's shape = [%s].", - x_rank, x_dims)); - auto axis = ctx->Attrs().Get>("axis"); - PADDLE_ENFORCE_GT( - axis.size(), 0, - platform::errors::InvalidArgument( - "The size of axis of logsumexp " - "should be greater than 0. But received the size of axis " - "of logsumexp is %d.", - axis.size())); - - for (size_t i = 0; i < axis.size(); i++) { - PADDLE_ENFORCE_LT(axis[i], x_rank, - platform::errors::InvalidArgument( - "axis[%d] should be in the " - "range [-D, D), where D is the dimensions of X and " - "D is %d. But received axis[%d] = %d.", - i, x_rank, i, axis[i])); - PADDLE_ENFORCE_GE(axis[i], -x_rank, - platform::errors::InvalidArgument( - "axis[%d] should be in the " - "range [-D, D), where D is the dimensions of X and " - "D is %d. But received axis[%d] = %d.", - i, x_rank, i, axis[i])); - if (axis[i] < 0) { - axis[i] += x_rank; - } - } - - bool keepdim = ctx->Attrs().Get("keepdim"); - bool reduce_all = ctx->Attrs().Get("reduce_all"); - auto dims_vector = vectorize(x_dims); - if (reduce_all) { - if (keepdim) - ctx->SetOutputDim("Out", - phi::make_ddim(std::vector(x_rank, 1))); - else - ctx->SetOutputDim("Out", {1}); - } else { - auto dims_vector = vectorize(x_dims); - if (keepdim) { - for (size_t i = 0; i < axis.size(); ++i) { - dims_vector[axis[i]] = 1; - } - } else { - const int kDelFlag = -1; - for (size_t i = 0; i < axis.size(); ++i) { - dims_vector[axis[i]] = kDelFlag; - } - dims_vector.erase( - std::remove(dims_vector.begin(), dims_vector.end(), kDelFlag), - dims_vector.end()); - } - if (!keepdim && dims_vector.size() == 0) { - dims_vector.push_back(1); - } - auto out_dims = phi::make_ddim(dims_vector); - ctx->SetOutputDim("Out", out_dims); - if (axis.size() > 0 && axis[0] != 0) { - // Only pass LoD when not reducing on the first dim. - ctx->ShareLoD("X", /*->*/ "Out"); - } - } - } }; class LogsumexpOpMaker : public framework::OpProtoAndCheckerMaker { @@ -164,16 +93,10 @@ class LogsumexpGradOpMaker : public framework::SingleGradOpMaker { } // namespace paddle namespace ops = paddle::operators; - +DECLARE_INFER_SHAPE_FUNCTOR(logsumexp, LogsumexpInferShapeFunctor, + PD_INFER_META(phi::LogsumexpInferMeta)); REGISTER_OPERATOR(logsumexp, ops::LogsumexpOp, ops::LogsumexpOpMaker, ops::LogsumexpGradOpMaker, - ops::LogsumexpGradOpMaker); + ops::LogsumexpGradOpMaker, + LogsumexpInferShapeFunctor); REGISTER_OPERATOR(logsumexp_grad, ops::LogsumexpGrapOp); - -REGISTER_OP_CPU_KERNEL( - logsumexp, ops::LogsumexpKernel, - ops::LogsumexpKernel); -REGISTER_OP_CPU_KERNEL( - logsumexp_grad, - ops::LogsumexpGradKernel, - ops::LogsumexpGradKernel); diff --git a/paddle/fluid/operators/reduce_ops/logsumexp_op.h b/paddle/fluid/operators/reduce_ops/logsumexp_op.h deleted file mode 100644 index 4490f08b2129ad0a1dfcd42602ce1ad6f694d1f7..0000000000000000000000000000000000000000 --- a/paddle/fluid/operators/reduce_ops/logsumexp_op.h +++ /dev/null @@ -1,170 +0,0 @@ -// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#pragma once - -#include -#include -#include "paddle/fluid/operators/reduce_ops/reduce_op_function.h" - -namespace paddle { -namespace operators { - -#define HANDLE_DIM(NDIM, RDIM) \ - if (ndim == NDIM && rdim == RDIM) { \ - paddle::operators::ReduceFunctor( \ - context.template device_context(), *input, output, \ - axis, keepdim); \ - } - -struct LogsumexpFunctor { - template - void operator()(const DeviceContext& place, X* x, Y* y, const Dim& dim) { - auto x_dim = x->dimensions(); - auto t_dim = x_dim; - for (int i = 0; i < static_cast(dim.size()); i++) { - t_dim[dim[i]] = 1; - } - - auto r_dim = x_dim; - for (int i = 0; i < static_cast(r_dim.size()); i++) { - r_dim[i] = 1; - } - for (int i = 0; i < static_cast(dim.size()); i++) { - r_dim[dim[i]] = x_dim[dim[i]]; - } - - auto y_dim = y->dimensions(); - auto x_max = x->maximum(dim); - y->device(place) = - (x_max + - (*x - x_max.reshape(t_dim).broadcast(r_dim)).exp().sum(dim).log()) - .reshape(y_dim); - } -}; - -struct LogsumexpGradFunctor { - template - void operator()(const DeviceContext& place, X* x, Y* y, DX* dx, DY* dy, - const Dim& dim, int size) { - dx->device(place) = dy->broadcast(dim) * (*x - y->broadcast(dim)).exp(); - } -}; - -template -class LogsumexpKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& context) const override { - auto* input = context.Input("X"); - auto* output = context.Output("Out"); - output->mutable_data(context.GetPlace()); - - auto axis = context.Attr>("axis"); - auto keepdim = context.Attr("keepdim"); - auto reduce_all = context.Attr("reduce_all"); - - const auto& input_dim_size = input->dims().size(); - // The dims has full dim, set the reduce_all is True - reduce_all |= (static_cast(axis.size()) == input_dim_size); - - if (reduce_all) { - // Flatten and reduce 1-D tensor - auto x = EigenVector::Flatten(*input); - auto out = EigenScalar::From(*output); - auto& place = - *context.template device_context().eigen_device(); - auto reduce_dim = Eigen::array({{0}}); - LogsumexpFunctor()(place, &x, &out, reduce_dim); - } else { - int ndim = input_dim_size; - int rdim = axis.size(); - // comments for accelerating compiling temporarily. - // HANDLE_DIM(6, 5); - // HANDLE_DIM(6, 4); - // HANDLE_DIM(6, 3); - // HANDLE_DIM(6, 2); - // HANDLE_DIM(6, 1); - // HANDLE_DIM(5, 4); - // HANDLE_DIM(5, 3); - // HANDLE_DIM(5, 2); - // HANDLE_DIM(5, 1); - HANDLE_DIM(4, 3); - HANDLE_DIM(4, 2); - HANDLE_DIM(4, 1); - HANDLE_DIM(3, 2); - HANDLE_DIM(3, 1); - HANDLE_DIM(2, 1); - } - } -}; - -template -class LogsumexpGradKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& context) const override { - auto* input = context.Input("X"); - auto* output = context.Input("Out"); - auto* output_grad = context.Input(framework::GradVarName("Out")); - auto* input_grad = context.Output(framework::GradVarName("X")); - input_grad->mutable_data(context.GetPlace()); - - auto axis = context.Attr>("axis"); - auto reduce_all = context.Attr("reduce_all"); - const auto input_dim_size = context.Input("X")->dims().size(); - reduce_all |= (static_cast(axis.size()) == input_dim_size); - - if (reduce_all) { - auto x = EigenVector::Flatten(*input); - auto y = EigenVector::Flatten(*output); - auto dy = EigenVector::Flatten(*output_grad); - auto dx = EigenVector::Flatten(*input_grad); - auto& place = - *context.template device_context().eigen_device(); - auto broadcast_dim = - Eigen::array({{static_cast(input->numel())}}); - LogsumexpGradFunctor()(place, &x, &y, &dx, &dy, broadcast_dim, - broadcast_dim[0]); - } else { - int rank = input->dims().size(); - LogsumexpGradFunctor functor; - switch (rank) { - case 1: - ReduceGradFunctor( - context.template device_context(), *input, *output, - *output_grad, input_grad, functor, axis); - break; - case 2: - ReduceGradFunctor( - context.template device_context(), *input, *output, - *output_grad, input_grad, functor, axis); - break; - case 3: - ReduceGradFunctor( - context.template device_context(), *input, *output, - *output_grad, input_grad, functor, axis); - break; - case 4: - ReduceGradFunctor( - context.template device_context(), *input, *output, - *output_grad, input_grad, functor, axis); - break; - } - } - } -}; - -} // namespace operators -} // namespace paddle diff --git a/paddle/fluid/operators/reduce_ops/logsumexp_op_xpu.cc b/paddle/fluid/operators/reduce_ops/logsumexp_op_xpu.cc index dcb849de0991bc93514a33d8ba93a6fe84d87093..6fb60fa1791571657d932edbec0b697544d3045e 100644 --- a/paddle/fluid/operators/reduce_ops/logsumexp_op_xpu.cc +++ b/paddle/fluid/operators/reduce_ops/logsumexp_op_xpu.cc @@ -14,7 +14,7 @@ #ifdef PADDLE_WITH_XPU -#include "paddle/fluid/operators/reduce_ops/logsumexp_op.h" +#include "paddle/fluid/operators/reduce_ops/reduce_op_function.h" #include "paddle/fluid/platform/device/xpu/xpu_header.h" #include "paddle/fluid/platform/device_context.h" diff --git a/paddle/phi/infermeta/unary.cc b/paddle/phi/infermeta/unary.cc index 9761f902c70adce29a591d109f5ee4fed9ee8bab..c5cc845625479dd76cf13bb82f96dae221af6ecc 100644 --- a/paddle/phi/infermeta/unary.cc +++ b/paddle/phi/infermeta/unary.cc @@ -804,6 +804,91 @@ void KthvalueInferMeta(const MetaTensor& x, indices->set_dtype(x.dtype()); } +void LogsumexpInferMeta(const MetaTensor& input, + const std::vector& axis, + bool keepdim, + bool reduce_all, + MetaTensor* out) { + auto x_dims = input.dims(); + auto x_rank = x_dims.size(); + std::vector formated_axis = axis; + PADDLE_ENFORCE_LE(x_rank, + 4, + errors::InvalidArgument( + "The input tensor X's dimensions of logsumexp " + "should be less or equal than 4. But received X's " + "dimensions = %d, X's shape = [%s].", + x_rank, + x_dims)); + PADDLE_ENFORCE_GT( + axis.size(), + 0, + errors::InvalidArgument( + "The size of axis of logsumexp " + "should be greater than 0. But received the size of axis " + "of logsumexp is %d.", + axis.size())); + + for (size_t i = 0; i < axis.size(); i++) { + PADDLE_ENFORCE_LT(axis[i], + x_rank, + errors::InvalidArgument( + "axis[%d] should be in the " + "range [-D, D), where D is the dimensions of X and " + "D is %d. But received axis[%d] = %d.", + i, + x_rank, + i, + axis[i])); + PADDLE_ENFORCE_GE(axis[i], + -x_rank, + errors::InvalidArgument( + "axis[%d] should be in the " + "range [-D, D), where D is the dimensions of X and " + "D is %d. But received axis[%d] = %d.", + i, + x_rank, + i, + axis[i])); + if (axis[i] < 0) { + formated_axis[i] += x_rank; + } + } + + auto dims_vector = vectorize(x_dims); + if (reduce_all) { + if (keepdim) + out->set_dims(phi::make_ddim(std::vector(x_rank, 1))); + else + out->set_dims({1}); + } else { + auto dims_vector = vectorize(x_dims); + if (keepdim) { + for (size_t i = 0; i < formated_axis.size(); ++i) { + dims_vector[formated_axis[i]] = 1; + } + } else { + const int kDelFlag = -1; + for (size_t i = 0; i < formated_axis.size(); ++i) { + dims_vector[formated_axis[i]] = kDelFlag; + } + dims_vector.erase( + std::remove(dims_vector.begin(), dims_vector.end(), kDelFlag), + dims_vector.end()); + } + if (!keepdim && dims_vector.size() == 0) { + dims_vector.push_back(1); + } + auto out_dims = phi::make_ddim(dims_vector); + out->set_dims(out_dims); + if (formated_axis.size() > 0 && formated_axis[0] != 0) { + // Only pass LoD when not reducing on the first dim. + out->share_lod(input); + } + } + out->set_dtype(input.dtype()); +} + void MatrixPowerInferMeta(const MetaTensor& x, int n, MetaTensor* out) { auto dims = x.dims(); auto n_dim = dims.size(); diff --git a/paddle/phi/infermeta/unary.h b/paddle/phi/infermeta/unary.h index 8a9876b11625c5aeaafac744496ff4240ec8cde0..3b6a34cff610dc4db85a906456842b467a851783 100644 --- a/paddle/phi/infermeta/unary.h +++ b/paddle/phi/infermeta/unary.h @@ -136,6 +136,12 @@ void KthvalueInferMeta(const MetaTensor& x, MetaTensor* indices, MetaConfig = MetaConfig()); +void LogsumexpInferMeta(const MetaTensor& input, + const std::vector& axis, + bool keepdim, + bool reduce_all, + MetaTensor* out); + void MatrixPowerInferMeta(const MetaTensor& x, int n, MetaTensor* out); void MaxOutInferMeta(const MetaTensor& x, diff --git a/paddle/fluid/operators/reduce_ops/logsumexp_op.part.cu b/paddle/phi/kernels/cpu/logsumexp_grad_kernel.cc similarity index 58% rename from paddle/fluid/operators/reduce_ops/logsumexp_op.part.cu rename to paddle/phi/kernels/cpu/logsumexp_grad_kernel.cc index 81124e4f070a54444f4305dc903280548ac10b60..e0ef67084b445ea354e002ee1df836065afcc78f 100644 --- a/paddle/fluid/operators/reduce_ops/logsumexp_op.part.cu +++ b/paddle/phi/kernels/cpu/logsumexp_grad_kernel.cc @@ -1,4 +1,4 @@ -// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -12,12 +12,11 @@ // See the License for the specific language governing permissions and // limitations under the License. -// .part used to speed up nvcc compile -#include "paddle/fluid/operators/reduce_ops/logsumexp_op.h" +#include "paddle/phi/kernels/logsumexp_grad_kernel.h" -namespace ops = paddle::operators; +#include "paddle/phi/backends/cpu/cpu_context.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/impl/logsumexp_grad_kernel_impl.h" -REGISTER_OP_CUDA_KERNEL( - logsumexp_grad, - ops::LogsumexpGradKernel, - ops::LogsumexpGradKernel); +PD_REGISTER_KERNEL( + logsumexp_grad, CPU, ALL_LAYOUT, phi::LogsumexpGradKernel, float, double) {} diff --git a/paddle/fluid/operators/reduce_ops/logsumexp_op.cu b/paddle/phi/kernels/cpu/logsumexp_kernel.cc similarity index 61% rename from paddle/fluid/operators/reduce_ops/logsumexp_op.cu rename to paddle/phi/kernels/cpu/logsumexp_kernel.cc index 86a31595ebaabcbc07fab64779c33566d5b020eb..06e0b30a9ca6567d04f7946a1732fd1483289e03 100644 --- a/paddle/fluid/operators/reduce_ops/logsumexp_op.cu +++ b/paddle/phi/kernels/cpu/logsumexp_kernel.cc @@ -1,4 +1,4 @@ -// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -12,10 +12,12 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "paddle/fluid/operators/reduce_ops/logsumexp_op.h" +#include "paddle/phi/kernels/logsumexp_kernel.h" -namespace ops = paddle::operators; +#include "paddle/phi/backends/cpu/cpu_context.h" +#include "paddle/phi/core/kernel_registry.h" -REGISTER_OP_CUDA_KERNEL( - logsumexp, ops::LogsumexpKernel, - ops::LogsumexpKernel); +#include "paddle/phi/kernels/impl/logsumexp_kernel_impl.h" + +PD_REGISTER_KERNEL( + logsumexp, CPU, ALL_LAYOUT, phi::LogsumexpKernel, float, double) {} diff --git a/paddle/phi/kernels/gpu/logsumexp_grad_kernel.cu b/paddle/phi/kernels/gpu/logsumexp_grad_kernel.cu new file mode 100644 index 0000000000000000000000000000000000000000..490b3e94045612cea551db515d273c7a0ef2c577 --- /dev/null +++ b/paddle/phi/kernels/gpu/logsumexp_grad_kernel.cu @@ -0,0 +1,22 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/logsumexp_grad_kernel.h" + +#include "paddle/phi/backends/cpu/cpu_context.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/impl/logsumexp_grad_kernel_impl.h" + +PD_REGISTER_KERNEL( + logsumexp_grad, GPU, ALL_LAYOUT, phi::LogsumexpGradKernel, float, double) {} diff --git a/paddle/phi/kernels/gpu/logsumexp_kernel.cu b/paddle/phi/kernels/gpu/logsumexp_kernel.cu new file mode 100644 index 0000000000000000000000000000000000000000..0f07a39ab113ae42ff0bd634c92f571979f51127 --- /dev/null +++ b/paddle/phi/kernels/gpu/logsumexp_kernel.cu @@ -0,0 +1,23 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/logsumexp_kernel.h" + +#include "paddle/phi/backends/cpu/cpu_context.h" +#include "paddle/phi/core/kernel_registry.h" + +#include "paddle/phi/kernels/impl/logsumexp_kernel_impl.h" + +PD_REGISTER_KERNEL( + logsumexp, GPU, ALL_LAYOUT, phi::LogsumexpKernel, float, double) {} diff --git a/paddle/phi/kernels/impl/logsumexp_grad_kernel_impl.h b/paddle/phi/kernels/impl/logsumexp_grad_kernel_impl.h new file mode 100644 index 0000000000000000000000000000000000000000..c2583ce8d32df7aedbc9022f1eac0b85e9d7d082 --- /dev/null +++ b/paddle/phi/kernels/impl/logsumexp_grad_kernel_impl.h @@ -0,0 +1,91 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include +#include + +#include "paddle/phi/kernels/funcs/eigen/common.h" +#include "paddle/phi/kernels/funcs/eigen/eigen_function.h" +#include "paddle/phi/kernels/funcs/reduce_grad_functions.h" +#include "paddle/phi/kernels/logsumexp_grad_kernel.h" + +namespace phi { + +struct LogsumexpGradFunctor { + template + void operator()(const Context& place, + X* x, + Y* y, + DX* dx, + DY* dy, + const Dim& dim, + int size) { + dx->device(place) = dy->broadcast(dim) * (*x - y->broadcast(dim)).exp(); + } +}; + +template +void LogsumexpGradKernel(const Context& dev_ctx, + const DenseTensor& in, + const DenseTensor& out, + const DenseTensor& out_grad, + const std::vector& axis, + bool keepdim, + bool reduce_all, + DenseTensor* in_grad) { + dev_ctx.template Alloc(in_grad); + + const auto input_dim_size = in.dims().size(); + reduce_all |= (static_cast(axis.size()) == input_dim_size); + + if (reduce_all) { + auto x = phi::EigenVector::Flatten(in); + auto y = phi::EigenVector::Flatten(out); + auto dy = phi::EigenVector::Flatten(out_grad); + auto dx = phi::EigenVector::Flatten(*in_grad); + auto& place = *dev_ctx.eigen_device(); + auto broadcast_dim = Eigen::array({{static_cast(in.numel())}}); + LogsumexpGradFunctor()( + place, &x, &y, &dx, &dy, broadcast_dim, broadcast_dim[0]); + } else { + int rank = in.dims().size(); + LogsumexpGradFunctor functor; + switch (rank) { + case 1: + phi::funcs::ReduceGradFunctor( + dev_ctx, in, out, out_grad, in_grad, functor, axis); + break; + case 2: + phi::funcs::ReduceGradFunctor( + dev_ctx, in, out, out_grad, in_grad, functor, axis); + break; + case 3: + phi::funcs::ReduceGradFunctor( + dev_ctx, in, out, out_grad, in_grad, functor, axis); + break; + case 4: + phi::funcs::ReduceGradFunctor( + dev_ctx, in, out, out_grad, in_grad, functor, axis); + break; + } + } +} + +} // namespace phi diff --git a/paddle/phi/kernels/impl/logsumexp_kernel_impl.h b/paddle/phi/kernels/impl/logsumexp_kernel_impl.h new file mode 100644 index 0000000000000000000000000000000000000000..7a9573ff522b0a2f3c9cc62e39054c434b55282d --- /dev/null +++ b/paddle/phi/kernels/impl/logsumexp_kernel_impl.h @@ -0,0 +1,100 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include +#include + +#include "paddle/phi/kernels/cpu/reduce.h" +#include "paddle/phi/kernels/funcs/eigen/common.h" +#include "paddle/phi/kernels/funcs/eigen/eigen_function.h" +#include "paddle/phi/kernels/logsumexp_kernel.h" + +namespace phi { + +#define HANDLE_DIM(NDIM, RDIM) \ + if (ndim == NDIM && rdim == RDIM) { \ + ReduceFunctor( \ + dev_ctx, x, out, axis, keepdim); \ + } + +struct LogsumexpFunctor { + template + void operator()(const Context& place, X* x, Y* y, const Dim& dim) { + auto x_dim = x->dimensions(); + auto t_dim = x_dim; + for (int i = 0; i < static_cast(dim.size()); i++) { + t_dim[dim[i]] = 1; + } + + auto r_dim = x_dim; + for (int i = 0; i < static_cast(r_dim.size()); i++) { + r_dim[i] = 1; + } + for (int i = 0; i < static_cast(dim.size()); i++) { + r_dim[dim[i]] = x_dim[dim[i]]; + } + + auto y_dim = y->dimensions(); + auto x_max = x->maximum(dim); + y->device(place) = + (x_max + + (*x - x_max.reshape(t_dim).broadcast(r_dim)).exp().sum(dim).log()) + .reshape(y_dim); + } +}; + +template +void LogsumexpKernel(const Context& dev_ctx, + const DenseTensor& x, + const std::vector& axis, + bool keepdim, + bool reduce_all, + DenseTensor* out) { + dev_ctx.template Alloc(out); + + const auto& input_dim_size = x.dims().size(); + // The dims has full dim, set the reduce_all is True + reduce_all |= (static_cast(axis.size()) == input_dim_size); + + if (reduce_all) { + // Flatten and reduce 1-D tensor + auto input = phi::EigenVector::Flatten(x); + auto output = phi::EigenScalar::From(*out); + auto& place = *dev_ctx.eigen_device(); + auto reduce_dim = Eigen::array({{0}}); + LogsumexpFunctor()(place, &input, &output, reduce_dim); + } else { + int ndim = input_dim_size; + int rdim = axis.size(); + // comments for accelerating compiling temporarily. + // HANDLE_DIM(6, 5); + // HANDLE_DIM(6, 4); + // HANDLE_DIM(6, 3); + // HANDLE_DIM(6, 2); + // HANDLE_DIM(6, 1); + // HANDLE_DIM(5, 4); + // HANDLE_DIM(5, 3); + // HANDLE_DIM(5, 2); + // HANDLE_DIM(5, 1); + HANDLE_DIM(4, 3); + HANDLE_DIM(4, 2); + HANDLE_DIM(4, 1); + HANDLE_DIM(3, 2); + HANDLE_DIM(3, 1); + HANDLE_DIM(2, 1); + } +} + +} // namespace phi diff --git a/paddle/phi/kernels/logsumexp_grad_kernel.h b/paddle/phi/kernels/logsumexp_grad_kernel.h new file mode 100644 index 0000000000000000000000000000000000000000..d68c447aa65cb2970e4f8897042f81acffce454f --- /dev/null +++ b/paddle/phi/kernels/logsumexp_grad_kernel.h @@ -0,0 +1,31 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/phi/core/dense_tensor.h" + +namespace phi { + +template +void LogsumexpGradKernel(const Context& ctx, + const DenseTensor& in, + const DenseTensor& out, + const DenseTensor& out_grad, + const std::vector& axis, + bool keepdim, + bool reduce_all, + DenseTensor* in_grad); + +} // namespace phi diff --git a/paddle/phi/kernels/logsumexp_kernel.h b/paddle/phi/kernels/logsumexp_kernel.h new file mode 100644 index 0000000000000000000000000000000000000000..ba1b18230fa52845bede77331cadd8d61bbd5244 --- /dev/null +++ b/paddle/phi/kernels/logsumexp_kernel.h @@ -0,0 +1,29 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/phi/core/dense_tensor.h" + +namespace phi { + +template +void LogsumexpKernel(const Context& ctx, + const DenseTensor& x, + const std::vector& axis, + bool keepdim, + bool reduce_all, + DenseTensor* out); + +} // namespace phi diff --git a/paddle/phi/ops/compat/logsumexp_sig.cc b/paddle/phi/ops/compat/logsumexp_sig.cc new file mode 100644 index 0000000000000000000000000000000000000000..ca7345dbe704999183a784489f13bea05e30fdc0 --- /dev/null +++ b/paddle/phi/ops/compat/logsumexp_sig.cc @@ -0,0 +1,29 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/core/compat/op_utils.h" + +namespace phi { + +KernelSignature LogsumexpGradOpArgumentMapping( + const ArgumentMappingContext& ctx) { + return KernelSignature("logsumexp_grad", + {"X", "Out", GradVarName("Out")}, + {"axis", "keepdim", "reduce_all"}, + {GradVarName("X")}); +} + +} // namespace phi + +PD_REGISTER_ARG_MAPPING_FN(logsumexp_grad, phi::LogsumexpGradOpArgumentMapping);