From 7e5fb4623b6d779d9ac13b3719ece5306dbdb938 Mon Sep 17 00:00:00 2001 From: xiongkun Date: Thu, 2 Sep 2021 19:29:49 +0800 Subject: [PATCH] Add SVD Op and it's GPU and CPU kernel (#34953) * Add SVD Op and it's GPU and CPU kernel * Remove CUDAPlace in test_svd_op, make the test available in CPU package * modfity the file * fix windows bug/ fix ROCM / fix test timeout * for pass the CIs * improve error report * for code review * some modification to test_svd_op * change python code style * expose the svd interface for document --- cmake/operators.cmake | 1 + paddle/fluid/operators/svd_helper.h | 372 ++++++++++++++++++ paddle/fluid/operators/svd_op.cc | 163 ++++++++ paddle/fluid/operators/svd_op.cu | 175 ++++++++ paddle/fluid/operators/svd_op.h | 145 +++++++ paddle/fluid/platform/dynload/cusolver.h | 8 +- python/paddle/__init__.py | 2 + .../fluid/tests/unittests/CMakeLists.txt | 1 + .../fluid/tests/unittests/test_svd_op.py | 292 ++++++++++++++ .../white_list/no_check_set_white_list.py | 1 + .../white_list/op_threshold_white_list.py | 1 + python/paddle/linalg.py | 2 + python/paddle/tensor/__init__.py | 2 + python/paddle/tensor/linalg.py | 64 +++ 14 files changed, 1228 insertions(+), 1 deletion(-) create mode 100644 paddle/fluid/operators/svd_helper.h create mode 100644 paddle/fluid/operators/svd_op.cc create mode 100644 paddle/fluid/operators/svd_op.cu create mode 100644 paddle/fluid/operators/svd_op.h create mode 100644 python/paddle/fluid/tests/unittests/test_svd_op.py diff --git a/cmake/operators.cmake b/cmake/operators.cmake index a200b948de..7730550e06 100644 --- a/cmake/operators.cmake +++ b/cmake/operators.cmake @@ -183,6 +183,7 @@ function(op_library TARGET) list(REMOVE_ITEM miopen_cu_cc_srcs "affine_grid_cudnn_op.cu.cc") list(REMOVE_ITEM miopen_cu_cc_srcs "grid_sampler_cudnn_op.cu.cc") list(REMOVE_ITEM hip_srcs "cholesky_op.cu") + list(REMOVE_ITEM hip_srcs "svd_op.cu") list(REMOVE_ITEM hip_srcs "multinomial_op.cu") list(REMOVE_ITEM hip_srcs "decode_jpeg_op.cu") hip_library(${TARGET} SRCS ${cc_srcs} ${hip_cc_srcs} ${miopen_cu_cc_srcs} ${miopen_cu_srcs} ${mkldnn_cc_srcs} ${hip_srcs} DEPS ${op_library_DEPS} diff --git a/paddle/fluid/operators/svd_helper.h b/paddle/fluid/operators/svd_helper.h new file mode 100644 index 0000000000..aa6a369728 --- /dev/null +++ b/paddle/fluid/operators/svd_helper.h @@ -0,0 +1,372 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once +#include +#include +#include +#include +#include "paddle/fluid/framework/ddim.h" +#include "paddle/fluid/framework/operator.h" +#include "paddle/fluid/framework/tensor.h" +#include "paddle/fluid/operators/math/blas.h" +#include "paddle/fluid/operators/math/functors.h" +#include "paddle/fluid/operators/math/math_function.h" +#include "paddle/fluid/platform/device_context.h" +#include "paddle/fluid/platform/for_range.h" + +namespace paddle { +namespace operators { +namespace math { +using Tensor = framework::Tensor; +using InTensors = std::vector; +using OutTensors = std::vector; +using OpName = std::string; + +template +void EigenSvd(const T* X, T* U, T* VH, T* S, int rows, int cols, + int full = false) { + auto flag = Eigen::DecompositionOptions::ComputeThinU | + Eigen::DecompositionOptions::ComputeThinV; + if (full) { + flag = Eigen::DecompositionOptions::ComputeFullU | + Eigen::DecompositionOptions::ComputeFullV; + } + Eigen::BDCSVD< + Eigen::Matrix> + svd(2, 2, flag); + /*NOTE(xiongkun03) Eigen::Matrix API need non-const pointer.*/ + T* input = const_cast(X); + auto m = Eigen::Map< + Eigen::Matrix>( + input, rows, cols); + svd.compute(m); + Eigen::Matrix V_trans = + svd.matrixV().transpose(); + memcpy(U, svd.matrixU().data(), svd.matrixU().size() * sizeof(T)); + memcpy(VH, V_trans.data(), V_trans.size() * sizeof(T)); + memcpy(S, svd.singularValues().data(), + svd.singularValues().size() * sizeof(T)); +} + +template +void BatchSvd(const T* X, T* U, T* VH, T* S, int rows, int cols, int batches, + int full = false) { + int stride = rows * cols; + int k = std::min(rows, cols); + int stride_u = full ? rows * rows : k * rows; + int stride_v = full ? cols * cols : k * cols; + for (int i = 0; i < batches; ++i) { + EigenSvd(X + i * stride, U + i * stride_u, VH + i * stride_v, S + i * k, + rows, cols, full); + } + return; +} + +template +struct PowFunctor { + PowFunctor(const T* input, T* output, int64_t numel, float exp) + : input_(input), output_(output), numel_(numel), exp_(exp) {} + + HOSTDEVICE void operator()(int64_t idx) const { + output_[idx] = pow(input_[idx], exp_); + } + const T* input_; + T* output_; + int64_t numel_; + float exp_; +}; + +static std::vector GetBroadcastShape(InTensors ins) { + // TODO(xiongkun03) check the operators and output + PADDLE_ENFORCE_EQ(ins.size(), 2, platform::errors::InvalidArgument( + "GetBroadcastShape Receive 2 tensors" + "but got [%d]", + ins.size())); + auto x_dim = ins[0]->dims(); + auto y_dim = ins[1]->dims(); + std::vector broadcast_shape = + (x_dim.size() > y_dim.size() ? framework::vectorize(x_dim) + : framework::vectorize(y_dim)); + int rank_min = std::min(x_dim.size(), y_dim.size()); + int rank_x = x_dim.size(); + int rank_y = y_dim.size(); + int final_rank = broadcast_shape.size(); + for (int i = 1; i <= rank_min; ++i) { + if (x_dim[rank_x - i] == y_dim[rank_y - i]) { + broadcast_shape[final_rank - i] = x_dim[rank_x - i]; + continue; + } + if (x_dim[rank_x - i] == 1) { + broadcast_shape[final_rank - i] = y_dim[rank_y - i]; + continue; + } + if (y_dim[rank_y - i] == 1) { + broadcast_shape[final_rank - i] = x_dim[rank_x - i]; + continue; + } + PADDLE_THROW(platform::errors::InvalidArgument( + "Wrong Input Shape in broadcast operator: " + "Input(X)'s shape must follow the broadcast rule with Input(Y)'s " + "shape, but received [%s] (X) vs [%s] (Y).", + x_dim, y_dim)); + } + return broadcast_shape; +} + +template +struct DeviceIndependenceTensorOperations { + // 1. Device indenpendence, for kernel reuse. + // 2. 
Input and output is always tensor type. + // 3. output Tensor is alway allocated + // 4. Basic Tensor operator is supported + // 5. The Reused Operator Kernel should only be considered as + // a wrap function + using NameInTensorMap = + std::map>; + using NameOutTensor = std::vector; + + explicit DeviceIndependenceTensorOperations( + const framework::ExecutionContext& context) + : context(context) {} + + framework::Tensor Pow(const framework::Tensor& x, float exp) { + framework::Tensor out; + auto for_range = GetForRange(x.numel()); + int numel = x.numel(); + PowFunctor functor(x.data(), out.mutable_data(x.dims(), x.place()), + numel, exp); + for_range(functor); + return out; + } + framework::Tensor Matmul(const framework::Tensor& mat_a, + const framework::Tensor& mat_b, bool trans_a = false, + bool trans_b = false) { + framework::AttributeMap attrs; + attrs["trans_x"] = trans_a; + attrs["trans_y"] = trans_b; + NameInTensorMap inputs({{"X", {&mat_a}}, {"Y", {&mat_b}}}); + auto a_dim = mat_a.dims(); + auto b_dim = mat_b.dims(); + std::vector x_vec = framework::vectorize(a_dim); + x_vec[x_vec.size() - 2] = a_dim[a_dim.size() - (trans_a ? 1 : 2)]; + x_vec[x_vec.size() - 1] = b_dim[b_dim.size() - (trans_b ? 2 : 1)]; + return CreateOpRunAndReturnTensor("matmul_v2", inputs, attrs, x_vec); + } + // transpose the last two dimision + framework::Tensor Transpose(const framework::Tensor& x) { + framework::Tensor out; + auto x_dim = x.dims(); + auto x_vec = framework::vectorize(x_dim); + int rank = x_vec.size(); + std::swap(x_vec[rank - 1], x_vec[rank - 2]); + std::vector out_shape = x_vec; + std::vector axis(rank); + for (int i = 0; i < rank; ++i) { + axis[i] = i; + } + std::swap(axis[rank - 1], axis[rank - 2]); + framework::AttributeMap attrs; + attrs["axis"] = axis; + NameInTensorMap inputs({{"X", {&x}}}); + return CreateOpRunAndReturnTensor("transpose2", inputs, attrs, out_shape, + {"Out", "XShape"}); + } + + framework::Tensor Diag(const framework::Tensor& x, int offset = 0, + int padding_value = 0) { + framework::AttributeMap attrs; + attrs["offset"] = offset; + attrs["padding_value"] = padding_value; + NameInTensorMap inputs({{"X", {&x}}}); + int x_rank = x.dims().size(); + std::vector out_shape; + if (x_rank == 2) { + PADDLE_ENFORCE_EQ(x.dims()[0], x.dims()[1], + platform::errors::InvalidArgument( + "if X is a Matrix, then X must be square")); + out_shape.push_back(x.dims()[0]); + } else if (x_rank == 1) { + out_shape.push_back(x.dims()[0]); + out_shape.push_back(x.dims()[0]); + } else { + PADDLE_THROW( + platform::errors::InvalidArgument("Rank must less or equal than 2")); + } + return CreateOpRunAndReturnTensor("diag_v2", inputs, attrs, out_shape); + } + + framework::Tensor Add(const framework::Tensor& x, + const framework::Tensor& y) { + InTensors ins({&x, &y}); + framework::AttributeMap attrs; + attrs["axis"] = -1; + std::vector out_shape = GetBroadcastShape({&x, &y}); + NameInTensorMap inputs({{"X", {&x}}, {"Y", {&y}}}); + return CreateOpRunAndReturnTensor("elementwise_add", inputs, attrs, + out_shape); + } + + framework::Tensor Mul(const framework::Tensor& x, + const framework::Tensor& y) { + InTensors ins({&x, &y}); + framework::AttributeMap attrs; + attrs["axis"] = -1; + std::vector out_shape = GetBroadcastShape({&x, &y}); + NameInTensorMap inputs({{"X", {&x}}, {"Y", {&y}}}); + return CreateOpRunAndReturnTensor("elementwise_mul", inputs, attrs, + out_shape); + } + + framework::Tensor Sub(const framework::Tensor& x, + const framework::Tensor& y) { + InTensors ins({&x, &y}); + 
framework::AttributeMap attrs; + attrs["axis"] = -1; + std::vector out_shape = GetBroadcastShape({&x, &y}); + NameInTensorMap inputs({{"X", {&x}}, {"Y", {&y}}}); + return CreateOpRunAndReturnTensor("elementwise_sub", inputs, attrs, + out_shape); + } + + const framework::Tensor Unsqueeze(const framework::Tensor& x, int axis = 0) { + // don't copy data, only change the dims + framework::Tensor out; + out.ShareDataWith(x); + std::vector out_shape = framework::vectorize(x.dims()); + if (axis >= 0) { + auto index = (out_shape.begin() + axis); + out_shape.insert(index, 1); + } else if (axis < 0) { + auto index = (out_shape.end() + axis + 1); + out_shape.insert(index, 1); + } + out.Resize(framework::make_ddim(out_shape)); + return out; + } + + framework::Tensor Zeros(std::vector shape, + framework::proto::VarType::Type dtype, + float fill_value) { + framework::AttributeMap attrs; + attrs["dtype"] = dtype; + attrs["shape"] = shape; + attrs["value"] = fill_value; + NameInTensorMap inputs({}); + return CreateOpRunAndReturnTensor("fill_constant", inputs, attrs, shape); + } + + framework::Tensor Infinits(std::vector shape, + framework::proto::VarType::Type dtype) { + framework::AttributeMap attrs; + attrs["dtype"] = dtype; + attrs["shape"] = shape; + attrs["str_value"] = std::string("inf"); + NameInTensorMap inputs({}); + return CreateOpRunAndReturnTensor("fill_constant", inputs, attrs, shape); + } + + framework::Tensor Eye(int n, framework::proto::VarType::Type dtype) { + auto output = Zeros({n}, dtype, 1); + auto ret = Diag(output); + return ret; + } + + framework::Tensor Slice(const framework::Tensor& x, std::vector axes, + std::vector starts, std::vector ends) { + std::vector new_axes = axes; + NameInTensorMap inputs({{"Input", {&x}}}); + std::vector out_shape = framework::vectorize(x.dims()); + int rank = out_shape.size(); + PADDLE_ENFORCE_EQ( + axes.size(), starts.size(), + platform::errors::InvalidArgument("Slice Operator Argument Invalided")); + PADDLE_ENFORCE_EQ( + ends.size(), starts.size(), + platform::errors::InvalidArgument("Slice Operator Argument Invalided")); + for (unsigned int i = 0; i < axes.size(); ++i) { + int axis = axes[i]; + if (axis < 0) axis = rank + axis; + new_axes[i] = axis; // change negative to positive + int st = starts[i]; + int ed = ends[i]; + PADDLE_ENFORCE_GT(ed, st, + platform::errors::InvalidArgument( + "C++ Slice Operation Not Support End < Start")); + out_shape[axis] = ed - st; + } + framework::AttributeMap attrs; + attrs["axes"] = new_axes; + attrs["starts"] = starts; + attrs["ends"] = ends; + return CreateOpRunAndReturnTensor("slice", inputs, attrs, out_shape); + } + + private: + const framework::ExecutionContext& context; + BlasT GetBlas() { + return math::GetBlas(context); + } + platform::ForRange GetForRange(int numel) { + auto& dev_ctx = context.template device_context(); + return platform::ForRange(dev_ctx, numel); + } + + framework::Tensor CreateOpRunAndReturnTensor( + const std::string& type, const NameInTensorMap& inputs, + const framework::AttributeMap& attrs, std::vector out_shape, + NameOutTensor out_str = {"Out"}) { + // varialble set dims must be LoDTensor / SelectedRowTensor + framework::Scope& local_scope = context.scope().NewScope(); + + framework::VariableNameMap op_outputs; + for (auto out_name : out_str) { + local_scope.Var("tmp_" + out_name)->GetMutable(); + op_outputs[out_name].emplace_back("tmp_" + out_name); + } + auto out_var = local_scope.Var("tmp_Out"); // return the Out + // create Out Tensor and allocat memory + 
out_var->GetMutable()->mutable_data( + framework::make_ddim(out_shape), context.GetPlace()); + // framework::make_ddim(out_shape) + framework::VariableNameMap op_inputs; + int counter = 0; + for (auto item : inputs) { + auto& tensors = item.second; + std::vector name_vector; + for (auto each_tensor : tensors) { + // create score variable and reset the tensor. + std::string _name = "tmp" + std::to_string(counter++); + auto in_var = local_scope.Var(_name); // create + framework::LoDTensor tmp_tns; + tmp_tns.ShareDataWith(*each_tensor); // tensor -> lodtensor + (*in_var->GetMutable()) = + tmp_tns; // initialize and set value + name_vector.emplace_back(_name); + } + op_inputs[item.first] = name_vector; + } + auto op = + framework::OpRegistry::CreateOp(type, op_inputs, op_outputs, attrs); + op->Run(local_scope, context.GetPlace()); + framework::Tensor out; + out.ShareDataWith(*(out_var->GetMutable())); + out.Resize(framework::make_ddim(out_shape)); + context.scope().DeleteScope(&local_scope); + return out; + } +}; +} // namespace math +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/operators/svd_op.cc b/paddle/fluid/operators/svd_op.cc new file mode 100644 index 0000000000..90c138c578 --- /dev/null +++ b/paddle/fluid/operators/svd_op.cc @@ -0,0 +1,163 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "paddle/fluid/operators/svd_op.h" +#include +#include +#include +#include +#include "paddle/fluid/framework/ddim.h" +#ifdef PADDLE_WITH_MKLDNN +#include "paddle/fluid/platform/mkldnn_helper.h" +#endif + +namespace paddle { +namespace operators { + +using DDim = framework::DDim; +static DDim UDDim(const DDim& x_dim, int k) { + // get x_dim and return the ddim of U + auto x_vec = vectorize(x_dim); + x_vec[x_vec.size() - 1] = k; + return framework::make_ddim(x_vec); +} +static DDim VHDDim(const DDim& x_dim, int k) { + // get x_dim and return the ddim of U + auto x_vec = vectorize(x_dim); + x_vec[x_vec.size() - 2] = k; + return framework::make_ddim(x_vec); +} +static DDim SDDim(const DDim& x_dim, int k) { + // get x_dim and return the ddim of U + auto x_vec = vectorize(x_dim); + x_vec[x_vec.size() - 2] = k; + x_vec.erase(x_vec.end() - 1); // rank - 1 + return framework::make_ddim(x_vec); +} + +class SvdOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext* ctx) const override { + OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "svd"); + OP_INOUT_CHECK(ctx->HasOutput("U"), "Output", "U", "svd"); + OP_INOUT_CHECK(ctx->HasOutput("VH"), "Output", "VH", "svd"); + OP_INOUT_CHECK(ctx->HasOutput("S"), "Output", "S", "svd"); + + auto in_dims = ctx->GetInputDim("X"); + int x_rank = in_dims.size(); + PADDLE_ENFORCE_GE(in_dims.size(), 2, + platform::errors::InvalidArgument( + "the rank of input must greater than 2")); + int m = in_dims[x_rank - 2]; + int n = in_dims[x_rank - 1]; + int k = std::min(m, n); + const bool full_uv = ctx->Attrs().Get("full_matrices"); + ctx->SetOutputDim("U", !full_uv ? UDDim(in_dims, k) : UDDim(in_dims, m)); + ctx->SetOutputDim("VH", !full_uv ? VHDDim(in_dims, k) : VHDDim(in_dims, n)); + ctx->SetOutputDim("S", SDDim(in_dims, k)); + + ctx->ShareLoD("X", /*->*/ "U"); + ctx->ShareLoD("X", /*->*/ "VH"); + ctx->ShareLoD("X", /*->*/ "S"); + } +}; + +class SvdOpMaker : public framework::OpProtoAndCheckerMaker { + public: + void Make() override { + AddInput("X", "(Tensor), The input tensor of svd op."); + AddOutput("U", "(Tensor), The output U tensor of svd op."); + AddOutput("S", "(Tensor), The output S tensor of svd op."); + AddOutput("VH", "(Tensor), The output VH tensor of svd op."); + AddAttr("full_matrices", + "(bool, default false) Only Compute the thin U and V" + "when set as True, the gradient have some random " + "attribute.") + .SetDefault(false); + AddComment(R"DOC( +Svd Operator. + +This operator is used to perform SVD operation for batched matrics $X$. 
+$$U, S, VH = svd(X)$$ + +)DOC"); + } +}; + +class SvdGradOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + void InferShape(framework::InferShapeContext* ctx) const override { + OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("U")), "Input", + "U@Grad", "SvdGrad"); + OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("VH")), "Input", + "VH@Grad", "SvdGrad"); + OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("S")), "Input", + "S@Grad", "SvdGrad"); + OP_INOUT_CHECK(ctx->HasInput("U"), "Input", "U", "SvdGrad"); + OP_INOUT_CHECK(ctx->HasInput("S"), "Input", "S", "SvdGrad"); + OP_INOUT_CHECK(ctx->HasInput("VH"), "Input", "VH", "SvdGrad"); + OP_INOUT_CHECK(ctx->HasOutput(framework::GradVarName("X")), "Output", + "X@Grad", "SvdGrad"); + + auto d_x = ctx->GetInputDim(("X")); + ctx->SetOutputDim(framework::GradVarName("X"), d_x); + } + + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override { + auto dtype = OperatorWithKernel::IndicateVarDataType(ctx, "X"); + return framework::OpKernelType(dtype, ctx.GetPlace()); + } +}; + +template +class SvdGradMaker : public framework::SingleGradOpMaker { + public: + using framework::SingleGradOpMaker::SingleGradOpMaker; + + void Apply(GradOpPtr retv) const override { + retv->SetType("svd_grad"); + retv->SetInput(framework::GradVarName("U"), this->OutputGrad("U")); + retv->SetInput(framework::GradVarName("VH"), this->OutputGrad("VH")); + retv->SetInput(framework::GradVarName("S"), this->OutputGrad("S")); + retv->SetInput("U", this->Output("U")); + retv->SetInput("VH", this->Output("VH")); + retv->SetInput("S", this->Output("S")); + retv->SetInput("X", this->Input("X")); + retv->SetAttrMap(this->Attrs()); + retv->SetOutput(framework::GradVarName("X"), this->InputGrad("X")); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; + +REGISTER_OPERATOR(svd, ops::SvdOp, ops::SvdOpMaker, + ops::SvdGradMaker, + ops::SvdGradMaker); + +REGISTER_OPERATOR(svd_grad, ops::SvdGradOp); + +REGISTER_OP_CPU_KERNEL(svd, ops::SvdCPUKernel, + ops::SvdCPUKernel); + +REGISTER_OP_CPU_KERNEL( + svd_grad, ops::SvdGradKernel, + ops::SvdGradKernel); diff --git a/paddle/fluid/operators/svd_op.cu b/paddle/fluid/operators/svd_op.cu new file mode 100644 index 0000000000..ade7496d64 --- /dev/null +++ b/paddle/fluid/operators/svd_op.cu @@ -0,0 +1,175 @@ +/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#ifndef PADDLE_WITH_HIP +// HIP not support cusolver + +#include +#include +#include +#include "paddle/fluid/memory/memory.h" +#include "paddle/fluid/operators/svd_op.h" +#include "paddle/fluid/platform/dynload/cusolver.h" + +namespace paddle { +namespace operators { + +template +class SvdGPUKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + auto& dev_ctx = + context.template device_context(); + + const Tensor* x = context.Input("X"); + Tensor* U = context.Output("U"); + Tensor* VH = context.Output("VH"); + Tensor* S = context.Output("S"); + const bool full_matrices = context.Attr("full_matrices"); + + auto& dims = x->dims(); + int batch_count = 1; + for (int i = 0; i < dims.size() - 2; i++) { + batch_count *= dims[i]; + } + int rank = dims.size(); + int m = dims[rank - 2]; + int n = dims[rank - 1]; + + auto* vh_data = VH->mutable_data(context.GetPlace()); + auto* s_data = S->mutable_data(context.GetPlace()); + auto* u_data = U->mutable_data(context.GetPlace()); + // NOTE:(@xiongkun03) + // matrices are assumed to be stored in column-major order in cusolver + // then view A as n x m and do A^T SVD, we can avoid transpose + // Must Copy X once, because the gesvdj will change the origin input matrix + Tensor x_tmp; + TensorCopy(*x, context.GetPlace(), &x_tmp); + auto info = memory::Alloc(dev_ctx, sizeof(int) * batch_count); + int* info_ptr = reinterpret_cast(info->ptr()); + + GesvdjBatched(dev_ctx, batch_count, n, m, std::min(m, n), + x_tmp.mutable_data(context.GetPlace()), vh_data, u_data, + s_data, info_ptr, !full_matrices); + + framework::DDim UT_dim = U->dims(); + std::swap(UT_dim[rank - 1], UT_dim[rank - 2]); // Get the dim of UT_dim + U->Resize(UT_dim); // U is entirely UT + auto dito = + math::DeviceIndependenceTensorOperations(context); + auto tmp_U = dito.Transpose(*U); + U->ShareDataWith(tmp_U); // U becomse UT, aka VT + } + void GesvdjBatched(const platform::CUDADeviceContext& dev_ctx, int batchSize, + int m, int n, int k, T* A, T* U, T* V, T* S, int* info, + int thin_UV = 1) const; +}; + +template <> +void SvdGPUKernel::GesvdjBatched( + const platform::CUDADeviceContext& dev_ctx, int batchSize, int m, int n, + int k, float* A, float* U, float* V, float* S, int* info, + int thin_UV) const { + /* compute singular vectors */ + const cusolverEigMode_t jobz = + CUSOLVER_EIG_MODE_VECTOR; /* compute singular vectors */ + gesvdjInfo_t gesvdj_params = NULL; + int lda = m; + int ldu = m; + int ldt = n; + int lwork = 0; + auto handle = dev_ctx.cusolver_dn_handle(); + PADDLE_ENFORCE_CUDA_SUCCESS( + platform::dynload::cusolverDnCreateGesvdjInfo(&gesvdj_params)); + PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cusolverDnSgesvdj_bufferSize( + handle, jobz, thin_UV, m, n, A, lda, S, U, ldu, V, ldt, &lwork, + gesvdj_params)); + auto workspace = memory::Alloc(dev_ctx, lwork * sizeof(float)); + float* workspace_ptr = reinterpret_cast(workspace->ptr()); + int stride_A = lda * n; + int stride_U = ldu * (thin_UV ? k : m); + int stride_V = ldt * (thin_UV ? 
k : n); + for (int i = 0; i < batchSize; ++i) { + PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cusolverDnSgesvdj( + handle, jobz, thin_UV, m, n, A + stride_A * i, lda, S + k * i, + U + stride_U * i, ldu, V + stride_V * i, ldt, workspace_ptr, lwork, + info, gesvdj_params)); + // check the error info + int error_info; + memory::Copy(platform::CPUPlace(), &error_info, + BOOST_GET_CONST(platform::CUDAPlace, dev_ctx.GetPlace()), info, + sizeof(int), dev_ctx.stream()); + PADDLE_ENFORCE_EQ( + error_info, 0, + platform::errors::PreconditionNotMet( + "For batch [%d]: CUSolver SVD is not zero. [%d]", i, error_info)); + } + PADDLE_ENFORCE_CUDA_SUCCESS( + platform::dynload::cusolverDnDestroyGesvdjInfo(gesvdj_params)); +} + +template <> +void SvdGPUKernel::GesvdjBatched( + const platform::CUDADeviceContext& dev_ctx, int batchSize, int m, int n, + int k, double* A, double* U, double* V, double* S, int* info, + int thin_UV) const { + /* compute singular vectors */ + const cusolverEigMode_t jobz = + CUSOLVER_EIG_MODE_VECTOR; /* compute singular vectors */ + gesvdjInfo_t gesvdj_params = NULL; + int lda = m; + int ldu = m; + int ldt = n; + int lwork = 0; + auto handle = dev_ctx.cusolver_dn_handle(); + PADDLE_ENFORCE_CUDA_SUCCESS( + platform::dynload::cusolverDnCreateGesvdjInfo(&gesvdj_params)); + PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cusolverDnDgesvdj_bufferSize( + handle, jobz, thin_UV, m, n, A, lda, S, U, ldu, V, ldt, &lwork, + gesvdj_params)); + auto workspace = memory::Alloc(dev_ctx, lwork * sizeof(double)); + double* workspace_ptr = reinterpret_cast(workspace->ptr()); + int stride_A = lda * n; + int stride_U = ldu * (thin_UV ? k : m); + int stride_V = ldt * (thin_UV ? k : n); + for (int i = 0; i < batchSize; ++i) { + PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cusolverDnDgesvdj( + handle, jobz, thin_UV, m, n, A + stride_A * i, lda, S + k * i, + U + stride_U * i, ldu, V + stride_V * i, ldt, workspace_ptr, lwork, + info, gesvdj_params)); + // check the error info + int error_info; + memory::Copy(platform::CPUPlace(), &error_info, + BOOST_GET_CONST(platform::CUDAPlace, dev_ctx.GetPlace()), info, + sizeof(int), dev_ctx.stream()); + PADDLE_ENFORCE_EQ( + error_info, 0, + platform::errors::PreconditionNotMet( + "For batch [%d]: CUSolver SVD is not zero. [%d]", i, error_info)); + } + PADDLE_ENFORCE_CUDA_SUCCESS( + platform::dynload::cusolverDnDestroyGesvdjInfo(gesvdj_params)); +} + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OP_CUDA_KERNEL(svd, ops::SvdGPUKernel, + ops::SvdGPUKernel); +REGISTER_OP_CUDA_KERNEL( + svd_grad, ops::SvdGradKernel, + ops::SvdGradKernel); +#endif // not PADDLE_WITH_HIP diff --git a/paddle/fluid/operators/svd_op.h b/paddle/fluid/operators/svd_op.h new file mode 100644 index 0000000000..1910effbea --- /dev/null +++ b/paddle/fluid/operators/svd_op.h @@ -0,0 +1,145 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/framework/operator.h" +#include "paddle/fluid/operators/math/complex_functors.h" +#include "paddle/fluid/operators/svd_helper.h" +#include "paddle/fluid/platform/for_range.h" + +namespace paddle { +namespace operators { +using Tensor = framework::Tensor; +using DDim = framework::DDim; + +template +class SvdCPUKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + const Tensor* x = context.Input("X"); + Tensor* U = context.Output("U"); + Tensor* VH = context.Output("VH"); + Tensor* S = context.Output("S"); + int full = context.Attr("full_matrices"); + + /*Create Tensors and output, set the dim ...*/ + auto numel = x->numel(); + auto* x_data = x->data(); + auto x_dims = x->dims(); + int rows = x_dims[x_dims.size() - 2]; + int cols = x_dims[x_dims.size() - 1]; + int k = std::min(rows, cols); + int col_u = full ? rows : k; + int col_v = full ? cols : k; + int batches = numel / (rows * cols); + auto* U_out = U->mutable_data>( + context.GetPlace(), + size_t(batches * rows * col_u * sizeof(math::Real))); + auto* VH_out = VH->mutable_data>( + context.GetPlace(), + size_t(batches * col_v * cols * sizeof(math::Real))); + auto* S_out = S->mutable_data>( + context.GetPlace(), size_t(batches * k * sizeof(math::Real))); + + /*SVD Use the Eigen Library*/ + math::BatchSvd(x_data, U_out, VH_out, S_out, rows, cols, batches, full); + } +}; + +template +class SvdGradKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const { + const framework::Tensor& U_const = *ctx.Input("U"); + const framework::Tensor& VH_const = *ctx.Input("VH"); + const framework::Tensor& S = *ctx.Input("S"); + framework::Tensor& dX = + *ctx.Output(framework::GradVarName("X")); + const framework::Tensor& dU_const = + *ctx.Input(framework::GradVarName("U")); + const framework::Tensor& dVH_const = + *ctx.Input(framework::GradVarName("VH")); + + const bool full = ctx.Attr("full_matrices"); + int m = dX.dims()[dX.dims().size() - 2]; + int n = dX.dims()[dX.dims().size() - 1]; + int k = S.dims()[S.dims().size() - 1]; + auto dito = math::DeviceIndependenceTensorOperations(ctx); + framework::Tensor U, VH, dU, dV, dVH; + if (full) { + // if full_matrices is set, slice the U and VT to k columns + U = dito.Slice(U_const, {-1}, {0}, {k}); + VH = dito.Slice(VH_const, {-2}, {0}, {k}); + dU = dito.Slice(dU_const, {-1}, {0}, {k}); + dVH = dito.Slice(dVH_const, {-2}, {0}, {k}); + } else { + U = U_const; + VH = VH_const; + dU = dU_const; + dVH = dVH_const; + } + auto s_inverse = dito.Pow(S, -1); + auto s_square = dito.Pow(S, 2); + auto F = + dito.Sub(dito.Unsqueeze(s_square, -2), dito.Unsqueeze(s_square, -1)); + F = dito.Add(F, dito.Diag(dito.Infinits({k}, U.type()))); + F = dito.Pow(F, -1); + Tensor sigma_term; + Tensor u_term; + Tensor v_term; + + if (ctx.HasInput(framework::GradVarName("S"))) { + const framework::Tensor& gS = + *ctx.Input(framework::GradVarName("S")); + sigma_term = dito.Mul(dito.Unsqueeze(gS, -2), U); + sigma_term = dito.Matmul(sigma_term, VH); + } + + if (ctx.HasInput(framework::GradVarName("U"))) { + auto UTG = dito.Matmul(U, dU, true, false); + auto GTU = dito.Matmul(dU, U, true, false); + u_term = dito.Mul(dito.Mul(dito.Sub(UTG, GTU), F), dito.Unsqueeze(S, -2)); + u_term = dito.Matmul(U, u_term); + if (m > k) { + auto project = + dito.Sub(dito.Eye(m, U.type()), dito.Matmul(U, U, false, true)); + u_term = 
dito.Add(u_term, dito.Mul(dito.Matmul(project, dU), + dito.Unsqueeze(s_inverse, -2))); + } + u_term = dito.Matmul(u_term, VH); + } + + if (ctx.HasInput(framework::GradVarName("VH"))) { + auto UTG = dito.Matmul(VH, dVH, false, true); + auto GTU = dito.Matmul(dVH, VH, false, true); + v_term = dito.Mul(dito.Matmul(dito.Mul(dito.Sub(UTG, GTU), F), VH), + dito.Unsqueeze(S, -1)); + if (n > k) { + auto project = + dito.Sub(dito.Eye(n, U.type()), dito.Matmul(VH, VH, true, false)); + v_term = dito.Add(v_term, dito.Mul(dito.Matmul(dVH, project), + dito.Unsqueeze(s_inverse, -1))); + } + v_term = dito.Matmul(U, v_term); + } + + dX.ShareDataWith(dito.Add(dito.Add(u_term, sigma_term), v_term)); + } +}; + +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/platform/dynload/cusolver.h b/paddle/fluid/platform/dynload/cusolver.h index 561f20af45..42583b6068 100644 --- a/paddle/fluid/platform/dynload/cusolver.h +++ b/paddle/fluid/platform/dynload/cusolver.h @@ -55,7 +55,13 @@ CUSOLVER_ROUTINE_EACH(DECLARE_DYNAMIC_LOAD_CUSOLVER_WRAP); #if CUDA_VERSION >= 9020 #define CUSOLVER_ROUTINE_EACH_R1(__macro) \ __macro(cusolverDnSpotrfBatched); \ - __macro(cusolverDnDpotrfBatched); + __macro(cusolverDnDpotrfBatched); \ + __macro(cusolverDnSgesvdj_bufferSize); \ + __macro(cusolverDnDestroyGesvdjInfo); \ + __macro(cusolverDnCreateGesvdjInfo); \ + __macro(cusolverDnDgesvdj_bufferSize); \ + __macro(cusolverDnSgesvdj); \ + __macro(cusolverDnDgesvdj); CUSOLVER_ROUTINE_EACH_R1(DECLARE_DYNAMIC_LOAD_CUSOLVER_WRAP) #endif diff --git a/python/paddle/__init__.py b/python/paddle/__init__.py index ce338275b2..24a7a666fb 100755 --- a/python/paddle/__init__.py +++ b/python/paddle/__init__.py @@ -100,6 +100,7 @@ from .tensor.linalg import bmm # noqa: F401 from .tensor.linalg import histogram # noqa: F401 from .tensor.linalg import mv # noqa: F401 from .tensor.linalg import matrix_power # noqa: F401 +from .tensor.linalg import svd # noqa: F401 from .tensor.logic import equal # noqa: F401 from .tensor.logic import greater_equal # noqa: F401 from .tensor.logic import greater_than # noqa: F401 @@ -493,6 +494,7 @@ __all__ = [ # noqa 'sqrt', 'cholesky', 'matrix_power', + 'svd', 'randperm', 'linspace', 'reshape', diff --git a/python/paddle/fluid/tests/unittests/CMakeLists.txt b/python/paddle/fluid/tests/unittests/CMakeLists.txt index 5ca9624b98..2c001614d1 100644 --- a/python/paddle/fluid/tests/unittests/CMakeLists.txt +++ b/python/paddle/fluid/tests/unittests/CMakeLists.txt @@ -858,6 +858,7 @@ set_tests_properties(test_multiprocess_dataloader_iterable_dataset_static PROPER set_tests_properties(test_lstm_cudnn_op PROPERTIES TIMEOUT 120) set_tests_properties(test_stack_op PROPERTIES TIMEOUT 120) set_tests_properties(test_bilinear_interp_v2_op PROPERTIES TIMEOUT 120) +set_tests_properties(test_svd_op PROPERTIES TIMEOUT 120) set_tests_properties(test_deformable_psroi_pooling PROPERTIES TIMEOUT 120) set_tests_properties(test_trilinear_interp_v2_op PROPERTIES TIMEOUT 120) set_tests_properties(test_imperative_static_runner_mnist PROPERTIES TIMEOUT 120) diff --git a/python/paddle/fluid/tests/unittests/test_svd_op.py b/python/paddle/fluid/tests/unittests/test_svd_op.py new file mode 100644 index 0000000000..c2d712b3d7 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_svd_op.py @@ -0,0 +1,292 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import unittest +import numpy as np +import paddle +import paddle.fluid as fluid +import paddle.fluid.layers as layers +import paddle.fluid.core as core +from op_test import OpTest, skip_check_grad_ci +from gradient_checker import grad_check +from decorator_helper import prog_scope + + +class TestSvdOp(OpTest): + def setUp(self): + paddle.enable_static() + self.generate_input() + self.generate_output() + self.op_type = "svd" + assert (hasattr(self, "_output_data")) + self.inputs = {"X": self._input_data} + self.attrs = {'full_matrices': self.get_full_matrices_option()} + self.outputs = { + "U": self._output_data[0], + "S": self._output_data[1], + "VH": self._output_data[2] + } + + def generate_input(self): + """ return a input_data and input_shape + """ + self._input_shape = (100, 1) + self._input_data = np.random.random(self._input_shape).astype("float64") + + def get_full_matrices_option(self): + return False + + def generate_output(self): + assert (hasattr(self, "_input_data")) + self._output_data = np.linalg.svd(self._input_data) + + def test_check_output(self): + self.check_output(no_check_set=['U', 'VH']) + + def test_svd_forward(self): + """ u matmul diag(s) matmul vt must become X + """ + single_input = self._input_data.reshape( + [-1, self._input_shape[-2], self._input_shape[-1]])[0] + paddle.disable_static() + dy_x = paddle.to_tensor(single_input) + dy_u, dy_s, dy_vt = paddle.linalg.svd(dy_x) + dy_out_x = dy_u.matmul(paddle.diag(dy_s)).matmul(dy_vt) + if (paddle.abs(dy_out_x - dy_x) < 1e-7).all(): + ... + else: + print("EXPECTED:\n", dy_x) + print("GOT :\n", dy_out_x) + raise RuntimeError("Check SVD Failed") + paddle.enable_static() + + def check_S_grad(self): + self.check_grad(['X'], ['S'], numeric_grad_delta=0.001) + + def check_U_grad(self): + self.check_grad(['X'], ['U'], numeric_grad_delta=0.001) + + def check_V_grad(self): + self.check_grad(['X'], ['VH'], numeric_grad_delta=0.001) + + def test_check_grad(self): + """ + remember the input matrix must be the full rank matrix, otherwise the gradient will stochatic because the u / v 's (n-k) freedom vectors + """ + self.check_S_grad() + self.check_U_grad() + self.check_V_grad() + + +class TestSvdCheckGrad2(TestSvdOp): + # NOTE(xiongkun03): because we want to construct some full rank matrics, + # so we can't specifize matrices which numel() > 100 + + no_need_check_grad = True + + def generate_input(self): + """ return a deterministic matrix, the range matrix; + vander matrix must be a full rank matrix. + """ + self._input_shape = (5, 5) + self._input_data = np.vander( + [2, 3, 4, 5, 6]).astype("float64").reshape(self._input_shape) + + +class TestSvdNormalMatrixSmall(TestSvdCheckGrad2): + def generate_input(self): + """ small matrix SVD. + """ + self._input_shape = (1, 1) + self._input_data = np.random.random(self._input_shape).astype("float64") + + +class TestSvdNormalMatrix6x3(TestSvdCheckGrad2): + def generate_input(self): + """ return a deterministic matrix, the range matrix; + vander matrix must be a full rank matrix. 
+ """ + self._input_shape = (6, 3) + self._input_data = np.array( + [[1.0, 2.0, 3.0], [0.0, 1.0, 5.0], [0.0, 0.0, 6.0], + [2.0, 4.0, 9.0], [3.0, 6.0, 8.0], + [3.0, 1.0, 0.0]]).astype("float64") + + +class TestSvdNormalMatrix3x6(TestSvdCheckGrad2): + def generate_input(self): + """ return a deterministic matrix, the range matrix; + vander matrix must be a full rank matrix. + """ + self._input_shape = (3, 6) + self._input_data = np.array( + [[1.0, 2.0, 3.0], [0.0, 1.0, 5.0], [0.0, 0.0, 6.0], + [2.0, 4.0, 9.0], [3.0, 6.0, 8.0], + [3.0, 1.0, 0.0]]).astype("float64") + self._input_data = self._input_data.transpose((-1, -2)) + + +class TestSvdNormalMatrix6x3Batched(TestSvdOp): + def generate_input(self): + self._input_shape = (10, 6, 3) + self._input_data = np.array( + [[1.0, 2.0, 3.0], [0.0, 1.0, 5.0], [0.0, 0.0, 6.0], + [2.0, 4.0, 9.0], [3.0, 6.0, 8.0], + [3.0, 1.0, 0.0]]).astype("float64") + self._input_data = np.stack([self._input_data] * 10, axis=0) + + def test_svd_forward(self): + """ test_svd_forward not support batched input, so disable this test. + """ + pass + + +class TestSvdNormalMatrix3x6Batched(TestSvdOp): + def generate_input(self): + """ return a deterministic matrix, the range matrix; + vander matrix must be a full rank matrix. + """ + self._input_shape = (10, 3, 6) + self._input_data = np.array( + [[1.0, 2.0, 3.0], [0.0, 1.0, 5.0], [0.0, 0.0, 6.0], + [2.0, 4.0, 9.0], [3.0, 6.0, 8.0], + [3.0, 1.0, 0.0]]).astype("float64") + self._input_data = self._input_data.transpose((-1, -2)) + self._input_data = np.stack([self._input_data] * 10, axis=0) + + def test_svd_forward(self): + """ test_svd_forward not support batched input, so disable this test. + """ + pass + + +class TestSvdNormalMatrix3x3x3x6Batched(TestSvdOp): + def generate_input(self): + """ return a deterministic matrix, the range matrix; + vander matrix must be a full rank matrix. + """ + self._input_shape = (3, 3, 3, 6) + self._input_data = np.array( + [[1.0, 2.0, 3.0], [0.0, 1.0, 5.0], [0.0, 0.0, 6.0], + [2.0, 4.0, 9.0], [3.0, 6.0, 8.0], + [3.0, 1.0, 0.0]]).astype("float64") + self._input_data = self._input_data.transpose((-1, -2)) + self._input_data = np.stack( + [self._input_data, self._input_data, self._input_data], axis=0) + self._input_data = np.stack( + [self._input_data, self._input_data, self._input_data], axis=0) + + def test_svd_forward(self): + """ test_svd_forward not support batched input, so disable this test. + """ + pass + + +@skip_check_grad_ci(reason="'check_grad' on large inputs is too slow, " + + "however it is desirable to cover the forward pass") +class TestSvdNormalMatrixBig(TestSvdOp): + def generate_input(self): + """ big matrix SVD. + + """ + self._input_shape = (2, 200, 300) + self._input_data = np.random.random(self._input_shape).astype("float64") + + def test_svd_forward(self): + """ test_svd_forward not support batched input, so disable this test. + """ + pass + + def test_check_grad(self): + pass + + +class TestSvdNormalMatrixBig2(TestSvdOp): + def generate_input(self): + """ big matrix SVD. 
+ """ + self._input_shape = (1, 100) + self._input_data = np.random.random(self._input_shape).astype("float64") + + +class TestSvdNormalMatrixFullMatrices(unittest.TestCase): + def setUp(self): + paddle.disable_static() + + def tearDown(self): + paddle.enable_static() + + def test_full_matrices(self): + mat_shape = (2, 3) + mat = np.random.random(mat_shape).astype("float64") + x = paddle.to_tensor(mat) + u, s, vh = paddle.linalg.svd(x, full_matrices=True) + assert (u.shape == [2, 2]) + assert (vh.shape == [3, 3]) + x_recover = u.matmul(paddle.diag(s)).matmul(vh[0:2]) + if ((paddle.abs(x_recover - x) > 1e-4).any()): + raise RuntimeError("mat can't be recovered\n") + + +class TestSvdFullMatriceGrad(TestSvdNormalMatrix6x3): + def get_full_matrices_option(self): + return True + + def test_svd_forward(self): + """ test_svd_forward not support full matrices, so disable this test. + """ + pass + + def test_check_grad(self): + """ + remember the input matrix must be the full rank matrix, otherwise the gradient will stochatic because the u / v 's (n-k) freedom vectors + """ + self.check_S_grad() + #self.check_U_grad() // don't check U grad, because U have freedom vector + self.check_V_grad() + + +class TestSvdAPI(unittest.TestCase): + def test_dygraph(self): + paddle.disable_static() + a = np.random.rand(5, 5) + x = paddle.to_tensor(a) + u, s, vh = paddle.linalg.svd(x) + gt_u, gt_s, gt_vh = np.linalg.svd(a, full_matrices=False) + self.assertTrue(np.allclose(s, gt_s)) + + def test_static(self): + paddle.enable_static() + places = [fluid.CPUPlace()] + if core.is_compiled_with_cuda(): + places.append(fluid.CUDAPlace(0)) + for place in places: + with fluid.program_guard(fluid.Program(), fluid.Program()): + a = np.random.rand(5, 5) + x = paddle.fluid.data( + name="input", shape=[5, 5], dtype='float64') + u, s, vh = paddle.linalg.svd(x) + exe = fluid.Executor(place) + gt_u, gt_s, gt_vh = np.linalg.svd(a, full_matrices=False) + fetches = exe.run(fluid.default_main_program(), + feed={"input": a}, + fetch_list=[s]) + self.assertTrue(np.allclose(fetches[0], gt_s)) + + +if __name__ == "__main__": + paddle.enable_static() + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/white_list/no_check_set_white_list.py b/python/paddle/fluid/tests/unittests/white_list/no_check_set_white_list.py index 2492caff2f..584c418675 100644 --- a/python/paddle/fluid/tests/unittests/white_list/no_check_set_white_list.py +++ b/python/paddle/fluid/tests/unittests/white_list/no_check_set_white_list.py @@ -31,5 +31,6 @@ no_check_set_white_list = [ 'rnn', 'fusion_lstm', 'softmax_with_cross_entropy', + 'svd', 'class_center_sample', ] diff --git a/python/paddle/fluid/tests/unittests/white_list/op_threshold_white_list.py b/python/paddle/fluid/tests/unittests/white_list/op_threshold_white_list.py index 929a9696d1..2b3383239a 100644 --- a/python/paddle/fluid/tests/unittests/white_list/op_threshold_white_list.py +++ b/python/paddle/fluid/tests/unittests/white_list/op_threshold_white_list.py @@ -46,6 +46,7 @@ NEED_FIX_FP64_CHECK_GRAD_THRESHOLD_OP_LIST = [ 'cudnn_lstm', \ 'rnn', \ 'lgamma', \ + 'svd', \ 'matrix_power', \ ] diff --git a/python/paddle/linalg.py b/python/paddle/linalg.py index ec6b7aa9e3..236150eef9 100644 --- a/python/paddle/linalg.py +++ b/python/paddle/linalg.py @@ -16,10 +16,12 @@ from .tensor.linalg import cholesky # noqa: F401 from .tensor.linalg import norm # noqa: F401 from .tensor.linalg import matrix_power # noqa: F401 from .tensor import inverse as inv # noqa: F401 +from .tensor.linalg import svd __all__ 
= [ 'cholesky', #noqa 'norm', 'inv', + 'svd', 'matrix_power' ] diff --git a/python/paddle/tensor/__init__.py b/python/paddle/tensor/__init__.py index 040bec2f67..375375c860 100755 --- a/python/paddle/tensor/__init__.py +++ b/python/paddle/tensor/__init__.py @@ -45,6 +45,7 @@ from .linalg import bmm # noqa: F401 from .linalg import histogram # noqa: F401 from .linalg import mv # noqa: F401 from .linalg import matrix_power # noqa: F401 +from .linalg import svd # noqa: F401 from .logic import equal # noqa: F401 from .logic import greater_equal # noqa: F401 from .logic import greater_than # noqa: F401 @@ -223,6 +224,7 @@ tensor_method_func = [ #noqa 'histogram', 'mv', 'matrix_power', + 'svd', 'abs', 'acos', 'all', diff --git a/python/paddle/tensor/linalg.py b/python/paddle/tensor/linalg.py index 74d9876cdd..40dfd32b50 100644 --- a/python/paddle/tensor/linalg.py +++ b/python/paddle/tensor/linalg.py @@ -943,6 +943,70 @@ def mv(x, vec, name=None): return out +def svd(x, full_matrices=False, name=None): + r""" + Computes the singular value decomposition of one + matrix or batches of regular matrice. + Args: + x (Tensor): The input tensor. Its shape should be `[..., N, M]`, + where ... is zero or more batch dimensions. N and M can be arbitraty + positive number. Note that if x is sigular matrices, the grad is numerical + instability. The data type of x should be float32 or float64. + + full_matrices(bool): A flag to control the behavor of svd. + If full_matrices = True, svd op will compute full U and V matrics, + which means shape of U is `[..., N, N]`, shape of V is `[..., M, M]`. + If full_matrices = False, svd op will use a economic method to store U and V. + which means shape of U is `[..., N, K]`, shape of V is `[..., M, K]` + + Returns: + Tensor: Tensor U, the shape of U is controlled by full_matrices flag. + Tensor: Tensor S, the singular value of X. the shape of S is [..., K] + Tensor: Tensor VH, the conjugate transpose of V. the shape of V is controlled by full_matrices flag. + + import numpy as np + + x = paddle.to_tensor([[1.0, 2.0], [1.0, 3.0], [4.0, 6.0]]).astype('float64') + x = x.reshape([3, 2]) + u, s, vt = paddle.linalg.svd(x) + print (u) + print (s) + print (vt) + + #U = [[ 0.27364809, -0.21695147 ], + # [ 0.37892198, -0.87112408 ], + # [ 0.8840446 , 0.44053933 ]] + + #S = [8.14753743, 0.78589688] + + #VT= [[ 0.51411221, 0.85772294], + # [ 0.85772294, -0.51411221]] + + # one can verify : U * S * VT = X ; + # U * UH = I ; + # V * VH = I + """ + + if in_dygraph_mode(): + return _C_ops.svd(x, 'full_matrices', full_matrices) + check_variable_and_dtype(x, 'dtype', ['float32', 'float64'], 'svd') + check_type(full_matrices, 'full_matrices', bool, 'svd') + helper = LayerHelper('svd', **locals()) + u = helper.create_variable_for_type_inference(dtype=x.dtype) + vh = helper.create_variable_for_type_inference(dtype=x.dtype) + s = helper.create_variable_for_type_inference(dtype=x.dtype) + attrs = dict() + attrs['full_matrices'] = full_matrices + helper.append_op( + type='svd', + inputs={'X': [x]}, + outputs={'U': u, + 'VH': vh, + 'S': s}, + attr=attrs, ) + return u, s, vh + + def matrix_power(x, n, name=None): r""" Computes the n-th power of a square matrix or a batch of square matrices. -- GitLab
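As a usage illustration of the paddle.linalg.svd API exposed by this change, the short Python sketch below (not part of the patch itself) mirrors the reconstruction check performed in test_svd_forward and compares the singular values against numpy.linalg.svd; it assumes a Paddle build that already contains this operator, and the tolerances are illustrative.

# Hedged usage sketch, assuming a Paddle build that includes the svd op added by this patch.
import numpy as np
import paddle

x_np = np.random.rand(6, 3).astype("float64")
x = paddle.to_tensor(x_np)

# Thin SVD (full_matrices=False): U is [6, 3], S is [3], VH is [3, 3].
u, s, vh = paddle.linalg.svd(x, full_matrices=False)

# Reconstruction: U * diag(S) * VH should recover X.
x_rec = u.matmul(paddle.diag(s)).matmul(vh)
assert np.allclose(x_rec.numpy(), x_np, atol=1e-6)

# Singular values agree with numpy's reference implementation.
_, s_np, _ = np.linalg.svd(x_np, full_matrices=False)
assert np.allclose(s.numpy(), s_np, atol=1e-6)

With full_matrices=True the extra columns of U (and rows of VH) span the orthogonal complement, so the reconstruction only uses the first k = min(m, n) columns of U and the first k rows of VH; this is why TestSvdNormalMatrixFullMatrices slices vh[0:2] before multiplying.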