// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <algorithm>
#include <vector>

#include "paddle/phi/kernels/funcs/eigen/common.h"
#include "paddle/phi/kernels/funcs/eigen/eigen_function.h"
#include "paddle/phi/kernels/funcs/reduce_grad_functions.h"
#include "paddle/phi/kernels/logsumexp_grad_kernel.h"

namespace phi {

// Computes dx = dy * exp(x - y), where y = logsumexp(x); y and dy are
// broadcast from the reduced shape back to the shape of x.
struct LogsumexpGradFunctor {
  template <typename Context,
            typename X,
            typename Y,
            typename DX,
            typename DY,
            typename Dim>
  void operator()(const Context& place,
                  X* x,
                  Y* y,
                  DX* dx,
                  DY* dy,
                  const Dim& dim,
                  int size) {
    dx->device(place) = dy->broadcast(dim) * (*x - y->broadcast(dim)).exp();
  }
};

template <typename T, typename Context>
void LogsumexpGradKernel(const Context& dev_ctx,
                         const DenseTensor& in,
                         const DenseTensor& out,
                         const DenseTensor& out_grad,
                         const std::vector<int64_t>& axis,
                         bool keepdim,
                         bool reduce_all,
                         DenseTensor* in_grad) {
  dev_ctx.template Alloc<T>(in_grad);

  // Reducing over every axis is equivalent to reduce_all.
  const auto input_dim_size = in.dims().size();
  reduce_all |= (static_cast<int>(axis.size()) == input_dim_size);

  if (reduce_all) {
    // Flatten all tensors to 1-D and broadcast the scalar y/dy over numel().
    auto x = phi::EigenVector<T>::Flatten(in);
    auto y = phi::EigenVector<T>::Flatten(out);
    auto dy = phi::EigenVector<T>::Flatten(out_grad);
    auto dx = phi::EigenVector<T>::Flatten(*in_grad);
    auto& place = *dev_ctx.eigen_device();
    auto broadcast_dim =
        Eigen::array<int, 1>({{static_cast<int>(in.numel())}});
    LogsumexpGradFunctor()(
        place, &x, &y, &dx, &dy, broadcast_dim, broadcast_dim[0]);
  } else {
    int rank = in.dims().size();
    LogsumexpGradFunctor functor;
    std::vector<int32_t> axis32;
    axis32.reserve(axis.size());
    std::for_each(axis.begin(), axis.end(), [&axis32](const int64_t& t) {
      axis32.push_back(t);
    });
    // ReduceGradFunctor is templated on the tensor rank, so dispatch on it.
    switch (rank) {
      case 1:
        phi::funcs::ReduceGradFunctor<Context, T, 1, LogsumexpGradFunctor>(
            dev_ctx, in, out, out_grad, in_grad, functor, axis32);
        break;
      case 2:
        phi::funcs::ReduceGradFunctor<Context, T, 2, LogsumexpGradFunctor>(
            dev_ctx, in, out, out_grad, in_grad, functor, axis32);
        break;
      case 3:
        phi::funcs::ReduceGradFunctor<Context, T, 3, LogsumexpGradFunctor>(
            dev_ctx, in, out, out_grad, in_grad, functor, axis32);
        break;
      case 4:
        phi::funcs::ReduceGradFunctor<Context, T, 4, LogsumexpGradFunctor>(
            dev_ctx, in, out, out_grad, in_grad, functor, axis32);
        break;
    }
  }
}

}  // namespace phi
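
// ---------------------------------------------------------------------------
// Illustrative note (not part of the kernel). For y = logsumexp(x) over the
// reduced axes, dy/dx_i = exp(x_i) / sum_j exp(x_j) = exp(x_i - y), so the
// chain rule gives dx = dy * exp(x - y), which is exactly the expression
// LogsumexpGradFunctor above evaluates. A minimal standalone sketch of the
// identity using only the C++ standard library (kept commented out; it is not
// part of this header, and the file name main.cc is hypothetical):
//
//   // main.cc
//   #include <cmath>
//   #include <cstdio>
//
//   int main() {
//     const double x[3] = {0.5, 1.0, -2.0};
//     // Stable logsumexp: shift by the maximum before exponentiating.
//     double m = x[0];
//     for (double v : x) m = (v > m) ? v : m;
//     double s = 0.0;
//     for (double v : x) s += std::exp(v - m);
//     const double y = m + std::log(s);
//     // With upstream gradient dy == 1, dx_i = exp(x_i - y); the printed
//     // values sum to 1, since sum_i exp(x_i - y) = exp(-y) * exp(y) = 1.
//     for (double v : x) std::printf("dx = %f\n", std::exp(v - y));
//     return 0;
//   }
// ---------------------------------------------------------------------------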