// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "paddle/phi/kernels/nanmedian_grad_kernel.h" #include "paddle/phi/backends/cpu/cpu_context.h" #include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/kernels/funcs/math_function.h" #include "paddle/phi/kernels/funcs/nanmedian_utils.h" namespace phi { template void CalcMedianGradKernel(const Context& dev_ctx, const DenseTensor& x, const DenseTensor& median_index, const DenseTensor& out_grad, DenseTensor* x_grad) { T* dx_data = dev_ctx.template Alloc(x_grad); if (!dx_data) return; phi::funcs::SetConstant set_zero; set_zero(dev_ctx, x_grad, static_cast(0)); const int64_t* m_data = median_index.data(); const T* dout_data = out_grad.data(); int64_t numel = x.numel(); auto x_dim = x.dims(); int64_t rank = x_dim.size(); int64_t stride = x_dim[rank - 1]; int64_t pre_dim = numel / stride; int64_t i = 0; int64_t offset = 0; for (i = 0; i < pre_dim; i++) { if (m_data[2 * i] >= 0) { if (m_data[2 * i] == m_data[2 * i + 1]) { dx_data[offset + m_data[2 * i]] = dout_data[i]; } else { dx_data[offset + m_data[2 * i]] = dout_data[i] / static_cast(2.0); dx_data[offset + m_data[2 * i + 1]] = dout_data[i] / static_cast(2.0); } } offset += stride; } } template void NanmedianGradKernel(const Context& dev_ctx, const DenseTensor& x, const DenseTensor& median_index, const DenseTensor& out_grad, const IntArray& axes, bool keepdim UNUSED, DenseTensor* x_grad) { DenseTensor tmp_x; auto rank = x.dims().size(); if ((axes.size() == 0) || rank <= 1) { tmp_x = x; tmp_x.Resize({x.numel()}); CalcMedianGradKernel( dev_ctx, tmp_x, median_index, out_grad, x_grad); } else { funcs::PreprocessMedianKernel(dev_ctx, x, axes, &tmp_x); DenseTensor tmp_x_grad; tmp_x_grad.Resize(x_grad->dims()); CalcMedianGradKernel( dev_ctx, tmp_x, median_index, out_grad, &tmp_x_grad); dev_ctx.template Alloc(x_grad); funcs::PostprocessMedianGradKernel( dev_ctx, &tmp_x_grad, axes, x_grad); } } } // namespace phi PD_REGISTER_KERNEL(nanmedian_grad, CPU, ALL_LAYOUT, phi::NanmedianGradKernel, float, double, int, int64_t) {}