// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/kernels/funcs/blas/blas.h"
#include "paddle/phi/kernels/funcs/eigen/common.h"
#include "paddle/phi/kernels/funcs/math_function.h"

namespace phi {

// Backward kernel of the bilinear op:
//   out[b][k] = x[b] * W[k] * y[b]^T + bias[k]
// Gradients computed below:
//   dx    += (dout[:, k] broadcast * y) * W[k]^T    for each k
//   dy    += (dout[:, k] broadcast * x) * W[k]      for each k
//   dW[k]  = (dout[:, k] broadcast * x)^T * y
//   dbias  = sum over the batch dimension of dout
template <typename T, typename Context>
void BilinearGradKernel(const Context& ctx,
                        const DenseTensor& x,
                        const DenseTensor& y,
                        const DenseTensor& weight,
                        const DenseTensor& dout,
                        DenseTensor* dx,
                        DenseTensor* dy,
                        DenseTensor* dweight,
                        DenseTensor* dbias) {
  auto batch_size = x.dims()[0];
  auto weight_dims = weight.dims();
  int out_dim = weight_dims[0];
  auto x_dim = weight_dims[1];
  auto y_dim = weight_dims[2];
  auto x_mat = EigenMatrix<T>::From(x);
  auto y_mat = EigenMatrix<T>::From(y);

  auto dout_mat = EigenMatrix<T>::From(dout);
  auto& place = *ctx.eigen_device();
  // Create the intermediate variable to calculate the Output(Y@Grad).
  DenseTensor x_scale;
  x_scale.Resize(make_ddim({batch_size, x_dim}));
  ctx.template Alloc<T>(&x_scale);
  auto x_scale_mat = EigenMatrix<T>::From(x_scale);

  // Create the intermediate variable to calculate the Output(X@Grad).
  DenseTensor y_scale;
  y_scale.Resize(make_ddim({batch_size, y_dim}));
  ctx.template Alloc<T>(&y_scale);
  auto y_scale_mat = EigenMatrix<T>::From(y_scale);

  funcs::SetConstant<Context, T> set_zero;

  if (dx) {
    ctx.template Alloc<T>(dx);
    set_zero(ctx, dx, static_cast<T>(0));
  }

  if (dy) {
    ctx.template Alloc<T>(dy);
    set_zero(ctx, dy, static_cast<T>(0));
  }

  if (dweight) {
    ctx.template Alloc<T>(dweight);
  }

  auto blas = funcs::GetBlas<Context, T>(ctx);

  // Calculate the Output(X@Grad) and Output(Y@Grad).
  if (dx || dy || dweight) {
    Eigen::DSizes<int, 2> bcast_for_x(1, y_dim);
    Eigen::DSizes<int, 2> bcast_for_y(1, x_dim);
    Eigen::DSizes<int, 2> bcast_for_weight(1, x_dim);

    for (int i = 0; i < out_dim; ++i) {
      // W[i] viewed as an (x_dim, y_dim) matrix.
      DenseTensor weight_i =
          weight.Slice(i, i + 1).Resize(make_ddim({x_dim, y_dim}));
      // The i-th output column of dout, shape (batch_size,).
      auto output_vec = dout_mat.chip(i, 1);

      if (dx) {
        // y_scale = dout[:, i] * y, then dx += y_scale * W[i]^T.
        y_scale_mat.device(place) =
            output_vec.reshape(Eigen::DSizes<int, 2>(batch_size, 1))
                .broadcast(bcast_for_x) *
            y_mat;
        blas.GEMM(CblasNoTrans,
                  CblasTrans,
                  batch_size,
                  x_dim,
                  y_dim,
                  1,
                  y_scale.data<T>(),
                  weight_i.data<T>(),
                  1,
                  dx->data<T>());
      }

      if (dy || dweight) {
        // x_scale = dout[:, i] * x, shared by the dy and dweight GEMMs.
        auto output_vec_y =
            output_vec.reshape(Eigen::DSizes<int, 2>(batch_size, 1))
                .broadcast(bcast_for_y);
        x_scale_mat.device(place) = output_vec_y * x_mat;
        if (dy) {
          // dy += x_scale * W[i].
          blas.GEMM(CblasNoTrans,
                    CblasNoTrans,
                    batch_size,
                    y_dim,
                    x_dim,
                    1,
                    x_scale.data<T>(),
                    weight_i.data<T>(),
                    1,
                    dy->data<T>());
        }
        if (dweight) {
          // dW[i] = x_scale^T * y.
          DenseTensor dweight_i =
              dweight->Slice(i, i + 1).Resize(make_ddim({x_dim, y_dim}));
          blas.GEMM(CblasTrans,
                    CblasNoTrans,
                    x_dim,
                    y_dim,
                    batch_size,
                    1,
                    x_scale.data<T>(),
                    y.data<T>(),
                    0,
                    dweight_i.data<T>());
        }
      }
    }
  }

  // Calculate the gradient of Input(Bias): sum dout over the batch dimension.
  if (dbias) {
    ctx.template Alloc<T>(dbias);
    auto dbias_mat = EigenVector<T>::Flatten(*dbias);
    dbias_mat.device(place) = dout_mat.sum(Eigen::DSizes<int, 1>(0));
  }
}

}  // namespace phi