From 4e21457d74c36da0cbc5f764100397a02437bd86 Mon Sep 17 00:00:00 2001
From: zhiboniu <31800336+zhiboniu@users.noreply.github.com>
Date: Thu, 30 Dec 2021 20:22:23 +0800
Subject: [PATCH] add OP lu forward (#38559)

---
 cmake/operators.cmake                         |   1 +
 paddle/fluid/operators/lu_op.cc               | 162 ++++
 paddle/fluid/operators/lu_op.cu               | 156 ++++
 paddle/fluid/operators/lu_op.h                | 474 ++++++++++++++++++
 paddle/fluid/platform/dynload/cusolver.h      |   8 +
 .../fluid/tests/unittests/test_lu_op.py       | 171 +++++++
 tools/static_mode_white_list.py               |   1 +
 7 files changed, 973 insertions(+)
 create mode 100644 paddle/fluid/operators/lu_op.cc
 create mode 100644 paddle/fluid/operators/lu_op.cu
 create mode 100644 paddle/fluid/operators/lu_op.h
 create mode 100644 python/paddle/fluid/tests/unittests/test_lu_op.py

diff --git a/cmake/operators.cmake b/cmake/operators.cmake
index ef25675d7dc..2d1ce4e8342 100644
--- a/cmake/operators.cmake
+++ b/cmake/operators.cmake
@@ -197,6 +197,7 @@ function(op_library TARGET)
     list(REMOVE_ITEM miopen_cu_cc_srcs "grid_sampler_cudnn_op.cu.cc")
     list(REMOVE_ITEM hip_srcs "cholesky_op.cu")
     list(REMOVE_ITEM hip_srcs "cholesky_solve_op.cu")
+    list(REMOVE_ITEM hip_srcs "lu_op.cu")
     list(REMOVE_ITEM hip_srcs "matrix_rank_op.cu")
     list(REMOVE_ITEM hip_srcs "svd_op.cu")
     list(REMOVE_ITEM hip_srcs "eigvalsh_op.cu")
diff --git a/paddle/fluid/operators/lu_op.cc b/paddle/fluid/operators/lu_op.cc
new file mode 100644
index 00000000000..d3997f848e0
--- /dev/null
+++ b/paddle/fluid/operators/lu_op.cc
@@ -0,0 +1,162 @@
+/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/fluid/operators/lu_op.h"
+
+namespace paddle {
+namespace operators {
+
+class LUOpMaker : public framework::OpProtoAndCheckerMaker {
+ public:
+  void Make() override {
+    AddComment(R"DOC(LU decomposition,
+                Computes the LU factorization of a matrix or batches of matrices A.
+                )DOC");
+    AddInput("X", "(Tensor) The input tensor, shape of (*,m,n)");
+    AddOutput("Out", "(Tensor) The output tensor, shape same to X");
+    AddOutput("Pivots",
+              "Stores all the intermediate transpositions of rows. 
shape of " + "(*,min(m,n))"); + AddOutput("Infos", + "(Tensor) This is a tensor of size (*) where non-zero values " + "indicate whether factorization for the matrix has succeeded"); + AddAttr("pivots", "Whether pivoting is done").SetDefault(true); + } +}; + +class LUOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext *context) const override { + OP_INOUT_CHECK(context->HasInput("X"), "Input", "X", "LU"); + OP_INOUT_CHECK(context->HasOutput("Out"), "Output", "Out", "LU"); + bool pivots = context->Attrs().Get("pivots"); + auto x_dims = context->GetInputDim("X"); + int x_rank = x_dims.size(); + PADDLE_ENFORCE_GE(x_rank, 2, platform::errors::InvalidArgument( + "the rank of input must greater than 2")); + context->SetOutputDim("Out", x_dims); + int m = x_dims[x_rank - 1]; + int n = x_dims[x_rank - 2]; + int min_mn = std::min(m, n); + auto dims_vec = framework::vectorize(x_dims); + OP_INOUT_CHECK(context->HasOutput("Infos"), "Output", "Infos", "LU"); + if (x_rank == 2) { + auto Infos_dim = std::vector(1); + context->SetOutputDim("Infos", framework::make_ddim(Infos_dim)); + } else { + auto Infos_dim = + std::vector(dims_vec.begin(), dims_vec.begin() + x_rank - 2); + context->SetOutputDim("Infos", framework::make_ddim(Infos_dim)); + } + if (pivots) { + OP_INOUT_CHECK(context->HasOutput("Pivots"), "Output", "Pivots", "LU"); + auto Pivots_dim = + std::vector(dims_vec.begin(), dims_vec.begin() + x_rank - 1); + Pivots_dim[x_rank - 2] = min_mn; + context->SetOutputDim("Pivots", framework::make_ddim(Pivots_dim)); + } + } + + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext &ctx) const override { + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace()); + } +}; + +class LUOpVarTypeInference : public framework::VarTypeInference { + public: + void operator()(framework::InferVarTypeContext *ctx) const override { + auto var_type = ctx->GetInputType("X", 0); + auto data_type = ctx->GetInputDataType("X", 0); + + ctx->SetOutputType("Out", var_type, framework::ALL_ELEMENTS); + ctx->SetOutputDataType("Out", data_type, framework::ALL_ELEMENTS); + + ctx->SetOutputType("Pivots", var_type, framework::ALL_ELEMENTS); + ctx->SetOutputDataType("Pivots", framework::proto::VarType::INT32, + framework::ALL_ELEMENTS); + + ctx->SetOutputType("Infos", var_type, framework::ALL_ELEMENTS); + ctx->SetOutputDataType("Infos", framework::proto::VarType::INT32, + framework::ALL_ELEMENTS); + } +}; + +template +class LUKernel : public framework::OpKernel { + public: + void Compute(const paddle::framework::ExecutionContext &ctx) const override { + auto pivots = ctx.Attr("pivots"); + auto *xin = ctx.Input("X"); + auto *out = ctx.Output("Out"); + auto *IpivT = ctx.Output("Pivots"); + auto *InfoT = ctx.Output("Infos"); + PADDLE_ENFORCE_EQ(pivots, true, + platform::errors::InvalidArgument( + "lu without pivoting is not implemented on the CPU, " + "but got pivots=False")); + + math::DeviceIndependenceTensorOperations + helper(ctx); + *out = helper.Transpose(*xin); + + auto outdims = out->dims(); + auto outrank = outdims.size(); + + int m = static_cast(outdims[outrank - 1]); + int n = static_cast(outdims[outrank - 2]); + int lda = std::max(1, m); + + auto ipiv_dims = slice_ddim(outdims, 0, outrank - 1); + ipiv_dims[outrank - 2] = std::min(m, n); + IpivT->Resize(ipiv_dims); + auto ipiv_data = IpivT->mutable_data(ctx.GetPlace()); + + auto 
info_dims = slice_ddim(outdims, 0, outrank - 2); + if (info_dims.size() == 0) { + info_dims = framework::make_ddim({1}); + } + InfoT->Resize(info_dims); + auto info_data = InfoT->mutable_data(ctx.GetPlace()); + + auto batchsize = product(info_dims); + batchsize = std::max(static_cast(batchsize), 1); + auto out_data = out->mutable_data(ctx.GetPlace()); + for (int b = 0; b < batchsize; b++) { + auto out_data_item = &out_data[b * m * n]; + int *info_data_item = &info_data[b]; + int *ipiv_data_item = &ipiv_data[b * std::min(m, n)]; + math::lapackLu(m, n, out_data_item, lda, ipiv_data_item, + info_data_item); + } + *out = helper.Transpose(*out); + } +}; + +DECLARE_INPLACE_OP_INFERER(LUOpInplaceInferer, {"X", "Out"}); + +} // namespace operators +} // namespace paddle +namespace ops = paddle::operators; +namespace plat = paddle::platform; + +REGISTER_OPERATOR(lu, ops::LUOp, ops::LUOpMaker, ops::LUOpVarTypeInference, + ops::LUOpInplaceInferer); + +REGISTER_OP_CPU_KERNEL(lu, ops::LUKernel, ops::LUKernel); diff --git a/paddle/fluid/operators/lu_op.cu b/paddle/fluid/operators/lu_op.cu new file mode 100644 index 00000000000..bd6dc712463 --- /dev/null +++ b/paddle/fluid/operators/lu_op.cu @@ -0,0 +1,156 @@ +/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#ifndef PADDLE_WITH_HIP +// HIP not support cusolver + +#include "paddle/fluid/memory/memory.h" +#include "paddle/fluid/operators/lu_op.h" +#include "paddle/fluid/platform/dynload/cusolver.h" + +namespace paddle { +namespace operators { + +using Tensor = framework::Tensor; +using CUDADeviceContext = paddle::platform::CUDADeviceContext; + +template +void cusolver_bufferSize(const cusolverDnHandle_t& cusolverH, int m, int n, + T* d_A, int lda, int* lwork); +template +void cusolver_getrf(const cusolverDnHandle_t& cusolverH, int m, int n, T* d_A, + int lda, T* d_work, int* d_Ipiv, int* d_info); + +template <> +void cusolver_bufferSize(const cusolverDnHandle_t& cusolverH, int m, + int n, float* d_A, int lda, int* lwork) { + PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cusolverDnSgetrf_bufferSize( + cusolverH, m, n, d_A, lda, lwork)); +} + +template <> +void cusolver_bufferSize(const cusolverDnHandle_t& cusolverH, int m, + int n, double* d_A, int lda, int* lwork) { + PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cusolverDnDgetrf_bufferSize( + cusolverH, m, n, d_A, lda, lwork)); +} + +template <> +void cusolver_getrf(const cusolverDnHandle_t& cusolverH, int m, int n, + float* d_A, int lda, float* d_work, int* d_Ipiv, + int* d_info) { + PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cusolverDnSgetrf( + cusolverH, m, n, d_A, lda, d_work, d_Ipiv, d_info)); +} + +template <> +void cusolver_getrf(const cusolverDnHandle_t& cusolverH, int m, int n, + double* d_A, int lda, double* d_work, int* d_Ipiv, + int* d_info) { + PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cusolverDnDgetrf( + cusolverH, m, n, d_A, lda, d_work, d_Ipiv, d_info)); +} + +template +void lu_decomposed_kernel(int m, int n, T* d_A, int lda, int* d_Ipiv, + int* d_info, const framework::ExecutionContext& ctx) { + /* step 1: get cusolver handle*/ + auto& dev_ctx = ctx.template device_context(); + auto cusolverH = dev_ctx.cusolver_dn_handle(); + + /* step 2: query working space of getrf */ + int lwork; + cusolver_bufferSize(cusolverH, m, n, d_A, lda, &lwork); + + auto work_buff = memory::Alloc(dev_ctx, lwork * sizeof(T)); + T* d_work = reinterpret_cast(work_buff->ptr()); + + /* step 3: LU factorization */ + if (d_Ipiv) { + cusolver_getrf(cusolverH, m, n, d_A, lda, d_work, d_Ipiv, d_info); + } else { + cusolver_getrf(cusolverH, m, n, d_A, lda, d_work, NULL, d_info); + } + PADDLE_ENFORCE_GPU_SUCCESS(cudaDeviceSynchronize()); +} + +template +class LUCUDAKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { +#ifdef __HIPCC__ + const int64_t kMaxBlockDim = 256; +#else + const int64_t kMaxBlockDim = 512; +#endif + auto* xin = ctx.Input("X"); + auto* out = ctx.Output("Out"); + auto* IpivT = ctx.Output("Pivots"); + auto* InfoT = ctx.Output("Infos"); + auto pivots = ctx.Attr("pivots"); + + math::DeviceIndependenceTensorOperations< + paddle::platform::CUDADeviceContext, T> + helper(ctx); + *out = helper.Transpose(*xin); + + auto outdims = out->dims(); + auto outrank = outdims.size(); + + int m = static_cast(outdims[outrank - 1]); + int n = static_cast(outdims[outrank - 2]); + int lda = std::max(1, m); + if (pivots) { + auto ipiv_dims = slice_ddim(outdims, 0, outrank - 1); + ipiv_dims[outrank - 2] = std::min(m, n); + IpivT->Resize(ipiv_dims); + } + auto ipiv_data = IpivT->mutable_data(ctx.GetPlace()); + + auto info_dims = slice_ddim(outdims, 0, outrank - 2); + if (info_dims.size() == 0) { + info_dims = framework::make_ddim({1}); + } + InfoT->Resize(info_dims); + auto info_data = 
InfoT->mutable_data(ctx.GetPlace()); + + auto batchsize = product(info_dims); + batchsize = std::max(static_cast(batchsize), 1); + auto out_data = out->mutable_data(ctx.GetPlace()); + for (int b = 0; b < batchsize; b++) { + auto out_data_item = &out_data[b * m * n]; + int* info_data_item = &info_data[b]; + if (pivots) { + auto ipiv_data_item = &ipiv_data[b * std::min(m, n)]; + lu_decomposed_kernel(m, n, out_data_item, lda, ipiv_data_item, + info_data_item, ctx); + } else { + lu_decomposed_kernel(m, n, out_data_item, lda, NULL, info_data_item, + ctx); + } + } + *out = helper.Transpose(*out); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +namespace plat = paddle::platform; + +REGISTER_OP_CUDA_KERNEL(lu, ops::LUCUDAKernel, + ops::LUCUDAKernel); + +#endif // not PADDLE_WITH_HIP diff --git a/paddle/fluid/operators/lu_op.h b/paddle/fluid/operators/lu_op.h new file mode 100644 index 00000000000..57cab052a25 --- /dev/null +++ b/paddle/fluid/operators/lu_op.h @@ -0,0 +1,474 @@ +/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/operators/math/lapack_function.h" +#include "paddle/fluid/operators/set_value_op.h" +#include "paddle/fluid/operators/svd_helper.h" +#include "paddle/fluid/operators/triangular_solve_op.h" +#include "paddle/fluid/operators/tril_triu_op.h" +#include "paddle/pten/include/math.h" + +namespace paddle { +namespace operators { + +using Tensor = framework::Tensor; +using LoDTensorArray = framework::LoDTensorArray; + +template +void SetValueCompute(const framework::ExecutionContext& ctx, + framework::Tensor* in, framework::Tensor* value_tensor, + framework::Tensor* out, const std::vector& axes, + std::vector* starts, std::vector* ends, + const std::vector& shape) { + std::vector steps = {1, 1}; + std::vector decrease_axes = {}; + std::vector none_axes = {}; + + auto dtype = in->type(); + + auto in_dims = in->dims(); + CheckAndUpdateSliceAttrs(in_dims, axes, starts, ends, &steps); + auto slice_dims = GetSliceDims(in_dims, axes, *starts, *ends, &steps); + auto decrease_slice_dims = GetDecreasedDims(slice_dims, decrease_axes); + + auto slice_dims_for_assign = decrease_slice_dims; + if (!none_axes.empty()) { + std::vector slice_dims_with_none; + + size_t none_axes_cur = 0, decrease_axes_cur = 0; + for (int i = 0; i < slice_dims.size(); ++i) { + while (none_axes_cur < none_axes.size() && + none_axes[none_axes_cur] <= i) { + slice_dims_with_none.push_back(1); + none_axes_cur++; + } + if (decrease_axes_cur < decrease_axes.size() && + decrease_axes[decrease_axes_cur] == i) { + decrease_axes_cur++; + } else { + slice_dims_with_none.push_back(slice_dims[i]); + } + } + while (none_axes_cur < none_axes.size()) { + slice_dims_with_none.push_back(1); + none_axes_cur++; + } + + slice_dims_for_assign = framework::make_ddim(slice_dims_with_none); + } + + auto place = ctx.GetPlace(); + auto& eigen_place = + *ctx.template 
device_context().eigen_device(); + + // Here copy data from input to avoid data loss at PE and Graph level. + // TODO(liym27): Speed up in the future version. + // - Q: Why don't call ShareDataWith to speed up? + // - A: Because it's not supported to ShareDataWith on OP's input and output + // https://github.com/PaddlePaddle/Paddle/wiki/ShareDataWith-and-ShareBufferWith-are-prohibited-in-OP + // - Q: Why don't delete Input, after all, the input and output are the same + // Tensor at program level? + // - A: If deleting Input, the graph will be complex, such as there will + // be two ops points to the output in graph: op1 -> output <- set_value. + // In this case, we have to find a way to handle the running order of + // set_value is what we want. + TensorCopy(*in, place, out); + + Tensor slice_tensor(dtype), pad_tensor(dtype); + slice_tensor.mutable_data(slice_dims, place); + pad_tensor.mutable_data(in_dims, place); + + auto pad_e = framework::EigenTensor::From(pad_tensor, in_dims); + auto out_e = framework::EigenTensor::From(*out); + auto slice_e = framework::EigenTensor::From(slice_tensor, slice_dims); + + // Step 1: Set the value of out at `_index` to zero + slice_e.device(eigen_place) = slice_e.constant(T(0)); + + auto starts_indices = Eigen::DSizes(); + auto ends_indices = Eigen::DSizes(); + auto strides_indices = Eigen::DSizes(); + + for (size_t i = 0; i < D; ++i) { + starts_indices[i] = 0; + ends_indices[i] = slice_dims[i]; + strides_indices[i] = 1; + } + for (size_t i = 0; i < axes.size(); i++) { + int axis_index = axes[i]; + starts_indices[axis_index] = (*starts)[i]; + ends_indices[axis_index] = (*ends)[i]; + strides_indices[axis_index] = steps[i]; + if ((*starts)[i] == + (*ends)[i]) { // slice is empty, data will not be changed + return; + } + } + + out_e.stridedSlice(starts_indices, ends_indices, strides_indices) + .device(eigen_place) = slice_e; + + // Step 2: Set a tensor with the same shape as out tensor. And its data at + // '_index' is the same as value_tensor, and data out of '_index' to zero + + // - Step 2.1 Set slice tensor with value + + // NOTE(liym27): [ Why resize slice_tensor here? ] + // A: When do broadcasting on slice_tensor and value_tensor, the shape of + // slice_tensor should be decreased dims. + // e.g. + // x[:,0] = value_tensor + // x's shape = [3, 4], value_tensor's shape = [3] + // We get slice_dims = [3, 1], decrease_slice_dims = [3] + // If do broadcasting on Tensor with shape [3, 1] and [3], the result's + // shape is [3, 3], which cross the border; + // If do broadcasting on Tensor with shape [3] and [3], the result's shape + // is [3], which is right. 
+ + slice_tensor.Resize(slice_dims_for_assign); + if (value_tensor != nullptr) { + CheckIsDimsMatch(slice_dims_for_assign, value_tensor->dims()); + // ElementwiseComputeEx can do broadcasting + ElementwiseComputeEx, DeviceContext, T>( + ctx, &slice_tensor, value_tensor, -1, SubFunctor(), &slice_tensor); + } else { + Tensor value_t(dtype); + auto value_dims = framework::make_ddim(shape); + CheckIsDimsMatch(slice_dims_for_assign, value_dims); + + value_t.mutable_data(value_dims, place); + auto value_name = GetValueName(dtype); + CopyVecotorToTensor(value_name.c_str(), &value_t, ctx); + value_t.Resize(value_dims); + ElementwiseComputeEx, DeviceContext, T>( + ctx, &slice_tensor, &value_t, -1, SubFunctor(), &slice_tensor); + } + slice_tensor.Resize(slice_dims); + + // - Step 2.2 Pad slice tensor with 0 + pad_e.device(eigen_place) = pad_e.constant(T(0)); + pad_e.stridedSlice(starts_indices, ends_indices, strides_indices) + .device(eigen_place) = slice_e; + + // Step 3: Set out tensor with value_tensor + out_e.device(eigen_place) = out_e - pad_e; +} + +template +void SetValueCompute_dispatch( + const framework::ExecutionContext& ctx, framework::Tensor* in, + framework::Tensor* value_tensor, framework::Tensor* out, + const std::vector& axes, std::vector* starts, + std::vector* ends, const std::vector& shape, int rank) { + switch (rank) { + case 1: + SetValueCompute(ctx, in, value_tensor, out, axes, + starts, ends, shape); + break; + case 2: + SetValueCompute(ctx, in, value_tensor, out, axes, + starts, ends, shape); + break; + case 3: + SetValueCompute(ctx, in, value_tensor, out, axes, + starts, ends, shape); + break; + case 4: + SetValueCompute(ctx, in, value_tensor, out, axes, + starts, ends, shape); + break; + case 5: + SetValueCompute(ctx, in, value_tensor, out, axes, + starts, ends, shape); + break; + case 6: + SetValueCompute(ctx, in, value_tensor, out, axes, + starts, ends, shape); + break; + default: + PADDLE_THROW(platform::errors::InvalidArgument( + "The rank of input should be less than 7, but received %d.", rank)); + } +} + +template +void Tensor_Conj(const DeviceContext& dev_ctx, const framework::Tensor& tensor, + framework::Tensor* out) { + out->Resize(tensor.dims()); + platform::ForRange out_for_range(dev_ctx, tensor.numel()); + math::ConjFunctor out_functor(tensor.data(), tensor.numel(), + out->mutable_data(dev_ctx.GetPlace())); + out_for_range(out_functor); +} + +template +void Tensor_Add(const DeviceContext& dev_ctx, const framework::Tensor& src1, + const framework::Tensor& src2, framework::Tensor* out) { + out->Resize(src1.dims()); + out->mutable_data(dev_ctx.GetPlace()); + auto pt_x = paddle::experimental::MakePtenDenseTensor(src1); + auto pt_y = paddle::experimental::MakePtenDenseTensor(src2); + auto pt_z = paddle::experimental::MakePtenDenseTensor(*out); + pten::Add(dev_ctx, *pt_x.get(), *pt_y.get(), -1, pt_z.get()); +} + +template +void Tensor_Sub(const DeviceContext& dev_ctx, const framework::Tensor& src1, + const framework::Tensor& src2, framework::Tensor* out) { + out->Resize(src1.dims()); + out->mutable_data(dev_ctx.GetPlace()); + auto pt_x = paddle::experimental::MakePtenDenseTensor(src1); + auto pt_y = paddle::experimental::MakePtenDenseTensor(src2); + auto pt_z = paddle::experimental::MakePtenDenseTensor(*out); + pten::Subtract(dev_ctx, *pt_x.get(), *pt_y.get(), -1, pt_z.get()); +} + +template +void SliceCompute(const framework::ExecutionContext& ctx, + const framework::Tensor* in, framework::Tensor* out, + const std::vector& axes_int, + const std::vector& starts_int, 
+ const std::vector& ends_int) { + std::vector axes(axes_int.begin(), axes_int.end()); + std::vector starts(starts_int.begin(), starts_int.end()); + std::vector ends(ends_int.begin(), ends_int.end()); + + std::vector decrease_axis = {}; + std::vector infer_flags = {}; + + PADDLE_ENFORCE_EQ( + starts.size(), axes.size(), + platform::errors::InvalidArgument( + "The size of starts must be equal to the size of axes.")); + PADDLE_ENFORCE_EQ(ends.size(), axes.size(), + platform::errors::InvalidArgument( + "The size of ends must be equal to the size of axes.")); + + // Step 2: Compute output + + auto in_dims = in->dims(); + auto out_dims = out->dims(); + auto slice_dims = out_dims; + + // 2.1 Infer output dims + for (size_t i = 0; i < axes.size(); ++i) { + // when start == -1 && end == start+1 + if (starts[i] == -1 && ends[i] == 0 && infer_flags[i] == -1) { + auto ret = std::find(decrease_axis.begin(), decrease_axis.end(), axes[i]); + if (ret != decrease_axis.end()) { + ends[i] = in_dims[axes[i]]; + } + } + } + + CheckAndUpdateSliceAttrs(in_dims, axes, &starts, &ends); + slice_dims = + GetSliceDims(in_dims, axes, starts, ends, nullptr, nullptr); + out_dims = GetDecreasedDims(slice_dims, decrease_axis); + + // 2.2 Get output + auto offsets = Eigen::DSizes(); + auto extents = Eigen::DSizes(); + + for (size_t i = 0; i < D; ++i) { + offsets[i] = 0; + extents[i] = slice_dims[i]; + } + for (size_t i = 0; i < axes.size(); ++i) { + offsets[axes[i]] = starts[i]; + } + + out->Resize(slice_dims); + out->mutable_data(ctx.GetPlace()); + + auto in_t = framework::EigenTensor::From(*in, in_dims); + auto out_t = framework::EigenTensor::From(*out, slice_dims); + auto& eigen_place = + *ctx.template device_context().eigen_device(); + + if (in->numel() <= Eigen::NumTraits::highest()) { + // similar to tf.slice: + // if element number less than INT_MAX, change the type of index to int + Eigen::DSizes offsets_32bit, extents_32bit; + for (size_t i = 0; i < D; i++) { + offsets_32bit[i] = offsets[i]; + extents_32bit[i] = extents[i]; + } + EigenSlice, T, D>::Eval( + eigen_place, framework::To32BitIndex(out_t), + framework::To32BitIndex(in_t), offsets_32bit, extents_32bit); + } else { + EigenSlice, T, D>::Eval( + eigen_place, out_t, in_t, offsets, extents); + } + + out->Resize(out_dims); + out->mutable_data(ctx.GetPlace()); +} + +template +void Tensor_narrow(const framework::ExecutionContext& ctx, + const framework::Tensor* src, framework::Tensor* out, + int row_s, int row_e, int col_s, int col_e) { + auto rank = src->dims().size(); + std::vector axes_int = {rank - 2, rank - 1}; + std::vector starts_int = {row_s, col_s}; + std::vector ends_int = {row_e, col_e}; + switch (rank) { + case 1: + SliceCompute(ctx, src, out, axes_int, starts_int, + ends_int); + break; + case 2: + SliceCompute(ctx, src, out, axes_int, starts_int, + ends_int); + break; + case 3: + SliceCompute(ctx, src, out, axes_int, starts_int, + ends_int); + break; + case 4: + SliceCompute(ctx, src, out, axes_int, starts_int, + ends_int); + break; + case 5: + SliceCompute(ctx, src, out, axes_int, starts_int, + ends_int); + break; + case 6: + SliceCompute(ctx, src, out, axes_int, starts_int, + ends_int); + break; + default: + PADDLE_THROW(platform::errors::InvalidArgument( + "The rank of input should be less than 7, but received %d.", rank)); + } +} + +template +void arange(const DeviceContext& dev_ctx, framework::Tensor* tmp, int w, + int batchsize = 1, int h = 1) { + tmp->Resize(framework::make_ddim({batchsize * w})); + platform::CPUPlace cpu; + auto tmpdata = 
tmp->mutable_data(cpu); + for (int b = 0; b < batchsize; b++) { + for (int i = 0; i < w; i++) { + tmpdata[b * w + i] = static_cast(b * h + i); + } + } +} + +template +struct OneFunctor { + OneFunctor(T* output, int* idtptr, int w, int dim) + : output_(output), idtptr_(idtptr), w_(w), dim_(dim) {} + + HOSTDEVICE void operator()(size_t idx) const { + output_[w_ * idtptr_[idx] + idx % dim_] = static_cast(1); + } + + T* output_; + int* idtptr_; + int w_; + int dim_; +}; + +template +void LU_Unpack(const DeviceContext& dev_ctx, const framework::Tensor* LU, + framework::Tensor* L, framework::Tensor* U) { + const auto udims = LU->dims(); + L->Resize(udims); + U->Resize(udims); + const auto H = udims[udims.size() - 2]; + const auto W = udims[udims.size() - 1]; + auto L_dataptr = L->mutable_data(dev_ctx.GetPlace()); + platform::ForRange x_for_range(dev_ctx, LU->numel()); + TrilTriuCompute tril_computer(LU->data(), -1, true, H, W, L_dataptr); + x_for_range(tril_computer); + + TrilTriuCompute triu_computer(LU->data(), 0, false, H, W, + U->mutable_data(dev_ctx.GetPlace())); + x_for_range(triu_computer); + + // set L's diagonal 1 + auto dim = std::min(H, W); + framework::Tensor rowtensor, rt_dev; + auto batchsize = product(framework::slice_ddim(udims, 0, udims.size() - 2)); + batchsize = std::max(static_cast(batchsize), 1); + arange(dev_ctx, &rowtensor, dim, batchsize, H); + auto idtptr = rowtensor.data(); + if (is_gpu_place(dev_ctx.GetPlace())) { + framework::TensorCopy(rowtensor, dev_ctx.GetPlace(), &rt_dev); + idtptr = rt_dev.data(); + } + + platform::ForRange for_range(dev_ctx, rowtensor.numel()); + OneFunctor functor(L_dataptr, idtptr, W, dim); + for_range(functor); +} + +template +void scatterpivot(const DeviceContext& dev_ctx, T* out_data, + framework::Tensor* idlst, int w, int dim) { + framework::Tensor idlst_tmp; + idlst_tmp.Resize(idlst->dims()); + idlst_tmp.mutable_data(dev_ctx.GetPlace()); + framework::TensorCopy(*idlst, dev_ctx.GetPlace(), &idlst_tmp); + auto idtptr = idlst_tmp.data(); + + platform::ForRange for_range(dev_ctx, idlst_tmp.numel()); + OneFunctor functor(out_data, idtptr, w, dim); + for_range(functor); +} + +template +void Unpack_Pivot(const DeviceContext& dev_ctx, const framework::Tensor& Pivot, + framework::Tensor* P, int h, int w) { + auto dims = Pivot.dims(); + auto Pdimvec = vectorize(dims); + auto prank = Pdimvec.size(); + auto Pnum = dims[prank - 1]; + framework::Tensor Pivot_cpu; + platform::CPUPlace cpu; + framework::TensorCopy(Pivot, cpu, &Pivot_cpu); + auto pdataptr = Pivot_cpu.data(); + Pdimvec[prank - 1] = h; + Pdimvec.emplace_back(h); + auto Pdim = framework::make_ddim(Pdimvec); + P->Resize(Pdim); + auto pdata = P->mutable_data(dev_ctx.GetPlace()); + math::SetConstant setter; + setter(dev_ctx, P, static_cast(0)); + + auto batchsize = product(framework::slice_ddim(dims, 0, prank - 1)); + batchsize = std::max(static_cast(batchsize), 1); + framework::Tensor idt; + for (int i = 0; i < batchsize; i++) { + arange(dev_ctx, &idt, h); + auto idlst = idt.data(); + for (int j = 0; j < Pnum; j++) { + if (idlst[pdataptr[i * Pnum + j] - 1] == idlst[j]) continue; + auto temp = idlst[j]; + idlst[j] = idlst[pdataptr[i * Pnum + j] - 1]; + idlst[pdataptr[i * Pnum + j] - 1] = temp; + } + scatterpivot(dev_ctx, &(pdata[i * h * h]), &idt, h, h); + } +} + +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/platform/dynload/cusolver.h b/paddle/fluid/platform/dynload/cusolver.h index b4b6d50e55e..f9dc6baea3c 100644 --- a/paddle/fluid/platform/dynload/cusolver.h +++ 
b/paddle/fluid/platform/dynload/cusolver.h
@@ -71,6 +71,10 @@ CUSOLVER_ROUTINE_EACH(DECLARE_DYNAMIC_LOAD_CUSOLVER_WRAP);
   __macro(cusolverDnSpotrsBatched);       \
   __macro(cusolverDnDpotrsBatched);       \
   __macro(cusolverDnSgesvdj_bufferSize);  \
+  __macro(cusolverDnSgetrf_bufferSize);   \
+  __macro(cusolverDnDgetrf_bufferSize);   \
+  __macro(cusolverDnCgetrf_bufferSize);   \
+  __macro(cusolverDnZgetrf_bufferSize);   \
   __macro(cusolverDnSgeqrf_bufferSize);   \
   __macro(cusolverDnDgeqrf_bufferSize);   \
   __macro(cusolverDnCgeqrf_bufferSize);   \
@@ -84,6 +88,10 @@ CUSOLVER_ROUTINE_EACH(DECLARE_DYNAMIC_LOAD_CUSOLVER_WRAP);
   __macro(cusolverDnDgesvdj_bufferSize);  \
   __macro(cusolverDnSgesvdj);             \
   __macro(cusolverDnDgesvdj);             \
+  __macro(cusolverDnSgetrf);              \
+  __macro(cusolverDnDgetrf);              \
+  __macro(cusolverDnCgetrf);              \
+  __macro(cusolverDnZgetrf);              \
   __macro(cusolverDnSgeqrf);              \
   __macro(cusolverDnDgeqrf);              \
   __macro(cusolverDnCgeqrf);              \
diff --git a/python/paddle/fluid/tests/unittests/test_lu_op.py b/python/paddle/fluid/tests/unittests/test_lu_op.py
new file mode 100644
index 00000000000..badd713132c
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/test_lu_op.py
@@ -0,0 +1,171 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+ +from __future__ import print_function +from op_test import OpTest +import unittest +import itertools +import numpy as np +import paddle +import paddle.fluid as fluid +import paddle.fluid.layers as layers +import paddle.fluid.core as core +import scipy +import scipy.linalg +import copy + + +def scipy_lu(A, pivot): + shape = A.shape + if len(shape) == 2: + return scipy.linalg.lu(A, permute_l=not pivot) + else: + preshape = shape[:-2] + batchsize = np.product(shape) // (shape[-2] * shape[-1]) + PP = [] + PL = [] + PU = [] + NA = A.reshape((-1, shape[-2], shape[-1])) + for b in range(batchsize): + P, L, U = scipy.linalg.lu(NA[b], permute_l=not pivot) + pshape = P.shape + lshape = L.shape + ushape = U.shape + PP.append(P) + PL.append(L) + PU.append(U) + return np.array(PP).reshape(preshape + pshape), np.array(PL).reshape( + preshape + lshape), np.array(PU).reshape(preshape + ushape) + + +def Pmat_to_perm(Pmat_org, cut): + Pmat = copy.deepcopy(Pmat_org) + shape = Pmat.shape + rows = shape[-2] + cols = shape[-1] + batchsize = max(1, np.product(shape[:-2])) + P = Pmat.reshape(batchsize, rows, cols) + permmat = [] + for b in range(batchsize): + permlst = [] + sP = P[b] + for c in range(min(rows, cols)): + idx = np.argmax(sP[:, c]) + permlst.append(idx) + tmp = copy.deepcopy(sP[c, :]) + sP[c, :] = sP[idx, :] + sP[idx, :] = tmp + + permmat.append(permlst) + Pivot = np.array(permmat).reshape(list(shape[:-2]) + [rows, ]) + 1 + return Pivot[..., :cut] + + +def perm_to_Pmat(perm, dim): + pshape = perm.shape + bs = int(np.product(perm.shape[:-1]).item()) + perm = perm.reshape((bs, pshape[-1])) + oneslst = [] + for i in range(bs): + idlst = np.arange(dim) + perm_item = perm[i, :] + for idx, p in enumerate(perm_item - 1): + temp = idlst[idx] + idlst[idx] = idlst[p] + idlst[p] = temp + + ones = paddle.eye(dim) + nmat = paddle.scatter(ones, paddle.to_tensor(idlst), ones) + oneslst.append(nmat) + return np.array(oneslst).reshape(list(pshape[:-1]) + [dim, dim]) + + +# m < n +class TestLUOp(OpTest): + """ + case 1 + """ + + def config(self): + self.x_shape = [3, 10, 12] + self.pivot = True + self.get_infos = True + self.dtype = "float64" + + def set_output(self): + X = self.inputs['X'] + sP, sl, sU = scipy_lu(X, self.pivot) + sL = np.tril(sl, -1) + ashape = np.array(X.shape) + lshape = np.array(sL.shape) + ushape = np.array(sU.shape) + + lpad = (len(sL.shape) - 2) * [(0, 0)] + list(( + (0, (ashape - lshape)[-2]), (0, (ashape - lshape)[-1]))) + upad = (len(sU.shape) - 2) * [(0, 0)] + list(( + (0, (ashape - ushape)[-2]), (0, (ashape - ushape)[-1]))) + + NsL = np.pad(sL, lpad) + NsU = np.pad(sU, upad) + NLU = NsL + NsU + self.output = NLU + self.Pivots = Pmat_to_perm(sP, min(ashape[-2], ashape[-1])) + self.Infos = np.zeros(self.x_shape[:-2]) if len( + X.shape) > 2 else np.array([0]) + + def setUp(self): + self.op_type = "lu" + self.config() + + self.inputs = {'X': np.random.random(self.x_shape).astype(self.dtype)} + self.attrs = {'pivots': self.pivot} + self.set_output() + self.outputs = { + 'Out': self.output, + 'Pivots': self.Pivots, + 'Infos': self.Infos + } + + def test_check_output(self): + self.check_output() + + +# m = n 2D +class TestLUOp2(TestLUOp): + """ + case 2 + """ + + def config(self): + self.x_shape = [10, 10] + self.pivot = True + self.get_infos = True + self.dtype = "float64" + + +# m > n +class TestLUOp3(TestLUOp): + """ + case 3 + """ + + def config(self): + self.x_shape = [2, 12, 10] + self.pivot = True + self.get_infos = True + self.dtype = "float64" + + +if __name__ == "__main__": + 
    unittest.main()
diff --git a/tools/static_mode_white_list.py b/tools/static_mode_white_list.py
index f3026dc1fae..694283264ca 100755
--- a/tools/static_mode_white_list.py
+++ b/tools/static_mode_white_list.py
@@ -727,6 +727,7 @@ STATIC_MODE_TESTING_LIST = [
     'test_class_center_sample_op',
     'test_fill_diagonal_tensor_op',
     'test_fill_any_op',
+    'test_lu_op',
     'test_margin_cross_entropy_op',
     'test_pull_gpups_sparse_op',
 ]
-- 
GitLab
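For reference, the output convention the new lu op is expected to follow (and that test_lu_op.py reconstructs with scipy): Out packs both factors of the row-permuted input, with U on and above the diagonal and the unit-lower L strictly below it, while Pivots stores 1-based row-interchange indices, as LAPACK/cusolver getrf do. The sketch below is an illustration of that convention only, using scipy.linalg.lu_factor on a single 4 x 4 matrix; it is not code from the patch, and the op itself also handles batched and non-square inputs.

# Minimal sketch (assumption: 2-D square input) of the packed-LU + pivots
# convention, checked against scipy rather than against the op itself.
import numpy as np
import scipy.linalg

A = np.random.rand(4, 4)

# scipy returns the packed factor and 0-based pivot indices; the op's
# "Pivots" output is the 1-based LAPACK-style equivalent.
lu, piv = scipy.linalg.lu_factor(A)

# Unpack L and U from the single packed matrix, in the spirit of
# LU_Unpack in lu_op.h: strictly-lower part plus an implicit unit diagonal.
L = np.tril(lu, -1) + np.eye(4)
U = np.triu(lu)

# Apply the recorded row interchanges to A; afterwards P @ A == L @ U.
PA = A.copy()
for i, p in enumerate(piv):
    PA[[i, p], :] = PA[[p, i], :]

assert np.allclose(PA, L @ U)

Infos plays the role of getrf's info output under the same convention: one entry per matrix in the batch, where zero means the factorization completed cleanly and a non-zero value flags a problem such as a singular U.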