// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #pragma once #include "paddle/phi/core/dense_tensor.h" #include "paddle/phi/kernels/activation_kernel.h" #include "paddle/phi/kernels/diag_kernel.h" #include "paddle/phi/kernels/elementwise_add_kernel.h" #include "paddle/phi/kernels/elementwise_multiply_kernel.h" #include "paddle/phi/kernels/elementwise_subtract_kernel.h" #include "paddle/phi/kernels/funcs/math_function.h" #include "paddle/phi/kernels/matmul_kernel.h" #include "paddle/phi/kernels/slice_kernel.h" namespace phi { template static DenseTensor Fill(const Context& ctx, std::vector shape, float fill_value) { DenseTensor ret; ret.Resize(make_ddim(shape)); ctx.template Alloc(&ret); funcs::SetConstant()(ctx, &ret, T(fill_value)); return ret; } template static DenseTensor Eye(const Context& dev_ctx, int n) { auto output = Fill(dev_ctx, {n}, 1); auto ret = Diag(dev_ctx, output, 0, 0); return ret; } template static DenseTensor Infinits(const Context& ctx, std::vector shape) { auto value = static_cast(std::numeric_limits::infinity()); return Fill(ctx, shape, value); } static DenseTensor Unsqueeze(const DenseTensor& x, int axis = 0) { // don't copy data, only change the dims DenseTensor out; out.ShareDataWith(x); std::vector out_shape = phi::vectorize(x.dims()); if (axis >= 0) { auto index = (out_shape.begin() + axis); out_shape.insert(index, 1); } else if (axis < 0) { auto index = (out_shape.end() + axis + 1); out_shape.insert(index, 1); } out.Resize(phi::make_ddim(out_shape)); return out; } template void SvdGradKernel(const Context& dev_ctx, const DenseTensor& x, const DenseTensor& u, const DenseTensor& vh, const DenseTensor& s, const DenseTensor& u_grad, const DenseTensor& vh_grad, const DenseTensor& s_grad, bool full_matrices, DenseTensor* x_grad) { const auto& dX = *x_grad; int m = dX.dims()[dX.dims().size() - 2]; int n = dX.dims()[dX.dims().size() - 1]; int k = s.dims()[s.dims().size() - 1]; DenseTensor U, VH, dU, dV, dVH; if (full_matrices) { // if full_matrices is set, slice the U and VT to k columns U = SliceKernel( dev_ctx, u, {u.dims().size() - 1}, {0}, {k}, {1}, {}); VH = SliceKernel( dev_ctx, vh, {vh.dims().size() - 2}, {0}, {k}, {1}, {}); dU = SliceKernel( dev_ctx, u_grad, {u_grad.dims().size() - 1}, {0}, {k}, {1}, {}); dVH = SliceKernel( dev_ctx, vh_grad, {vh.dims().size() - 2}, {0}, {k}, {1}, {}); } else { U = u; VH = vh; dU = u_grad; dVH = vh_grad; } auto s_inverse = Pow(dev_ctx, s, -1); auto s_square = Pow(dev_ctx, s, 2); auto F = Subtract( dev_ctx, Unsqueeze(s_square, -2), Unsqueeze(s_square, -1)); F = Add( dev_ctx, F, Diag(dev_ctx, Infinits(dev_ctx, {k}), 0, 0)); F = Pow(dev_ctx, F, -1); DenseTensor sigma_term; DenseTensor u_term; DenseTensor v_term; // if (ctx.HasInput(framework::GradVarName("S"))) { const DenseTensor& gS = s_grad; sigma_term = Multiply(dev_ctx, Unsqueeze(gS, -2), U); sigma_term = Matmul(dev_ctx, sigma_term, VH); } // if (ctx.HasInput(framework::GradVarName("U"))) { { auto UTG = Matmul(dev_ctx, U, dU, true, false); auto GTU = Matmul(dev_ctx, dU, U, true, false); u_term = Multiply( dev_ctx, Multiply( dev_ctx, Subtract(dev_ctx, UTG, GTU), F), Unsqueeze(s, -2)); u_term = Matmul(dev_ctx, U, u_term); if (m > k) { auto project = Subtract(dev_ctx, Eye(dev_ctx, m), Matmul(dev_ctx, U, U, false, true)); u_term = Add( dev_ctx, u_term, Multiply(dev_ctx, Matmul(dev_ctx, project, dU), Unsqueeze(s_inverse, -2))); } u_term = Matmul(dev_ctx, u_term, VH); } // } // if (ctx.HasInput(framework::GradVarName("VH"))) { { auto UTG = Matmul(dev_ctx, VH, dVH, false, true); auto GTU = Matmul(dev_ctx, dVH, VH, false, true); v_term = Multiply( dev_ctx, Matmul( dev_ctx, Multiply( dev_ctx, Subtract(dev_ctx, UTG, GTU), F), VH), Unsqueeze(s, -1)); if (n > k) { auto project = Subtract( dev_ctx, Eye(dev_ctx, n), Matmul(dev_ctx, VH, VH, true, false)); v_term = Add( dev_ctx, v_term, Multiply(dev_ctx, Matmul(dev_ctx, dVH, project), Unsqueeze(s_inverse, -1))); } v_term = Matmul(dev_ctx, U, v_term); } *x_grad = Add( dev_ctx, Add(dev_ctx, u_term, sigma_term), v_term); } } // namespace phi