/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #pragma once #include #include #include #include #include "paddle/fluid/framework/data_type.h" #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/operators/dot_op.h" #include "paddle/fluid/operators/math/blas.h" #include "paddle/fluid/operators/math/complex_functors.h" #include "paddle/fluid/operators/reduce_ops/reduce_sum_op.h" // only can include the headers in paddle/pten/api dirs #include "paddle/pten/api/lib/utils/tensor_utils.h" #include "paddle/pten/kernels/matmul_grad_kernel.h" #include "paddle/pten/kernels/matmul_kernel.h" #if defined(__NVCC__) || defined(__HIPCC__) #include "paddle/fluid/operators/reduce_ops/reduce_op.cu.h" #endif namespace paddle { namespace operators { using framework::Tensor; template class MatMulV2Kernel : public framework::OpKernel { public: void Compute(const paddle::framework::ExecutionContext& ctx) const override { auto* X = ctx.Input("X"); auto* Y = ctx.Input("Y"); auto* Out = ctx.Output("Out"); bool trans_x = ctx.Attr("trans_x"); bool trans_y = ctx.Attr("trans_y"); auto& dev_ctx = ctx.device_context(); Out->mutable_data(X->place()); // call new kernel pten::MatmulKernel(dev_ctx, *X, *Y, trans_x, trans_y, Out); } }; // Reshape a rank-3 tensor from P x M x N to (P * M) x N. // Identity op if the tensor is not of rank 3. static framework::Tensor FoldInitDims(const framework::Tensor& input) { auto output = input; auto in_dims = input.dims(); if (in_dims.size() == 3) { output.Resize({in_dims[0] * in_dims[1], in_dims[2]}); } return output; } /** * Get row matrix shape from a vector shape. If the rank of x_dim > 1, the * original x_dim is returned. */ static framework::DDim RowMatrixFromVector(const framework::DDim& x_dim) { if (x_dim.size() > 1) { return x_dim; } return framework::make_ddim({1, x_dim[0]}); } /** * Get column matrix shape from a vector shape. If the ran of y_dim > 1, the * original y_dim is returned. */ static framework::DDim ColumnMatrixFromVector(const framework::DDim& y_dim) { if (y_dim.size() > 1) { return y_dim; } return framework::make_ddim({y_dim[0], 1}); } /** * Reshape a tensor to 3-D or 2-D tensor by matrix descriptor. * * The shape would be [BatchSize, H, W] or [H, W]. * If transposed, `H,W` will be swapped. */ static void ReshapeTensorIntoMatrixSequence( framework::Tensor* x, const math::MatDescriptor& descriptor) { int64_t h, w; h = descriptor.height_; w = descriptor.width_; if (descriptor.trans_) { std::swap(w, h); } if (descriptor.batch_size_) { x->Resize({descriptor.batch_size_, h, w}); } else { x->Resize({h, w}); } } static void ReshapeXYOutIntoMatrixSequence(framework::Tensor* x, framework::Tensor* y, framework::Tensor* out, bool trans_x, bool trans_y) { auto x_dim = RowMatrixFromVector(x->dims()); auto y_dim = ColumnMatrixFromVector(y->dims()); auto mat_dim_x = math::CreateMatrixDescriptor(x_dim, 0, trans_x); auto mat_dim_y = math::CreateMatrixDescriptor(y_dim, 0, trans_y); if (mat_dim_x.batch_size_ == 0 && mat_dim_y.batch_size_ == 0) { out->Resize({mat_dim_x.height_, mat_dim_y.width_}); } else { out->Resize({(std::max)(mat_dim_x.batch_size_, mat_dim_y.batch_size_), mat_dim_x.height_, mat_dim_y.width_}); } ReshapeTensorIntoMatrixSequence(x, mat_dim_x); ReshapeTensorIntoMatrixSequence(y, mat_dim_y); } template class MatMulV2GradKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { bool transpose_x = ctx.Attr("trans_x"); bool transpose_y = ctx.Attr("trans_y"); auto* x = ctx.Input("X"); auto* y = ctx.Input("Y"); auto* dout = ctx.Input(framework::GradVarName("Out")); auto* dx = ctx.Output(framework::GradVarName("X")); auto* dy = ctx.Output(framework::GradVarName("Y")); if (dx) dx->mutable_data(ctx.GetPlace()); if (dy) dy->mutable_data(ctx.GetPlace()); auto& dev_ctx = ctx.device_context(); // call new kernel pten::MatmulGradKernel(dev_ctx, *x, *y, *dout, transpose_x, transpose_y, dx, dy); } }; template class MatMulV2DoubleGradKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& context) const override { auto* x = context.Input("X"); auto* y = context.Input("Y"); auto* dout = context.Input("DOut"); auto* ddx = context.Input("DDX"); auto* ddy = context.Input("DDY"); auto* dx = context.Output("DX"); auto* dy = context.Output("DY"); auto* ddout = context.Output("DDOut"); bool transpose_x = context.Attr("trans_x"); bool transpose_y = context.Attr("trans_y"); if (dx) dx->mutable_data(context.GetPlace()); if (dy) dy->mutable_data(context.GetPlace()); if (ddout) ddout->mutable_data(context.GetPlace()); auto& dev_ctx = context.device_context(); // call new kernel pten::MatmulDoubleGradKernel(dev_ctx, *x, *y, *dout, *ddx, *ddy, transpose_x, transpose_y, dx, dy, ddout); } }; template class MatMulV2TripleGradKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& context) const override { // get input auto* x = context.Input("X"); auto* y = context.Input("Y"); auto* dout = context.Input("DOut"); auto* ddx = context.Input("DDX"); auto* ddy = context.Input("DDY"); auto* d_dx = context.Input("D_DX"); auto* d_dy = context.Input("D_DY"); auto* d_ddout = context.Input("D_DDOut"); // get output auto* out_d_x = context.Output("D_X_out"); auto* out_d_y = context.Output("D_Y_out"); auto* out_d_dout = context.Output("D_DOut_out"); auto* out_d_ddx = context.Output("D_DDX_out"); auto* out_d_ddy = context.Output("D_DDY_out"); bool transpose_x = context.Attr("trans_x"); bool transpose_y = context.Attr("trans_y"); if (out_d_x) out_d_x->mutable_data(context.GetPlace()); if (out_d_y) out_d_y->mutable_data(context.GetPlace()); if (out_d_dout) out_d_dout->mutable_data(context.GetPlace()); if (out_d_ddx) out_d_ddx->mutable_data(context.GetPlace()); if (out_d_ddy) out_d_ddy->mutable_data(context.GetPlace()); auto& dev_ctx = context.device_context(); // call new kernel pten::MatmulTripleGradKernel( dev_ctx, *x, *y, *dout, *ddx, *ddy, *d_dx, *d_dy, *d_ddout, transpose_x, transpose_y, out_d_x, out_d_y, out_d_dout, out_d_ddx, out_d_ddy); } }; } // namespace operators } // namespace paddle