/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #pragma once #include #include #include #include #include "paddle/fluid/framework/data_type.h" #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/operators/dot_op.h" #include "paddle/fluid/operators/math/blas.h" #include "paddle/fluid/operators/reduce_ops/reduce_sum_op.h" #include "paddle/pten/kernels/funcs/complex_functors.h" // only can include the headers in paddle/pten/api dirs #include "paddle/pten/api/lib/utils/tensor_utils.h" #include "paddle/pten/kernels/matmul_grad_kernel.h" #include "paddle/pten/kernels/matmul_kernel.h" #if defined(__NVCC__) || defined(__HIPCC__) #include "paddle/fluid/operators/reduce_ops/reduce_op.cu.h" #endif namespace paddle { namespace operators { // Reshape a rank-3 tensor from P x M x N to (P * M) x N. // Identity op if the tensor is not of rank 3. static framework::Tensor FoldInitDims(const framework::Tensor& input) { auto output = input; auto in_dims = input.dims(); if (in_dims.size() == 3) { output.Resize({in_dims[0] * in_dims[1], in_dims[2]}); } return output; } /** * Get row matrix shape from a vector shape. If the rank of x_dim > 1, the * original x_dim is returned. */ static framework::DDim RowMatrixFromVector(const framework::DDim& x_dim) { if (x_dim.size() > 1) { return x_dim; } return framework::make_ddim({1, x_dim[0]}); } /** * Get column matrix shape from a vector shape. If the ran of y_dim > 1, the * original y_dim is returned. */ static framework::DDim ColumnMatrixFromVector(const framework::DDim& y_dim) { if (y_dim.size() > 1) { return y_dim; } return framework::make_ddim({y_dim[0], 1}); } /** * Reshape a tensor to 3-D or 2-D tensor by matrix descriptor. * * The shape would be [BatchSize, H, W] or [H, W]. * If transposed, `H,W` will be swapped. */ static void ReshapeTensorIntoMatrixSequence( framework::Tensor* x, const math::MatDescriptor& descriptor) { int64_t h, w; h = descriptor.height_; w = descriptor.width_; if (descriptor.trans_) { std::swap(w, h); } if (descriptor.batch_size_) { x->Resize({descriptor.batch_size_, h, w}); } else { x->Resize({h, w}); } } static void ReshapeXYOutIntoMatrixSequence(framework::Tensor* x, framework::Tensor* y, framework::Tensor* out, bool trans_x, bool trans_y) { auto x_dim = RowMatrixFromVector(x->dims()); auto y_dim = ColumnMatrixFromVector(y->dims()); auto mat_dim_x = math::CreateMatrixDescriptor(x_dim, 0, trans_x); auto mat_dim_y = math::CreateMatrixDescriptor(y_dim, 0, trans_y); if (mat_dim_x.batch_size_ == 0 && mat_dim_y.batch_size_ == 0) { out->Resize({mat_dim_x.height_, mat_dim_y.width_}); } else { out->Resize({(std::max)(mat_dim_x.batch_size_, mat_dim_y.batch_size_), mat_dim_x.height_, mat_dim_y.width_}); } ReshapeTensorIntoMatrixSequence(x, mat_dim_x); ReshapeTensorIntoMatrixSequence(y, mat_dim_y); } } // namespace operators } // namespace paddle