diff --git a/paddle/fluid/operators/svd_op.cc b/paddle/fluid/operators/svd_op.cc index 6c250675b629610944a167c3be476af2d9a317c8..7f9fccddf729a4e5479e477514cb6e2ad82493d2 100644 --- a/paddle/fluid/operators/svd_op.cc +++ b/paddle/fluid/operators/svd_op.cc @@ -17,8 +17,10 @@ #include #include +#include "paddle/fluid/framework/infershape_utils.h" #include "paddle/fluid/framework/op_registry.h" #include "paddle/phi/core/ddim.h" +#include "paddle/phi/infermeta/unary.h" #ifdef PADDLE_WITH_MKLDNN #include "paddle/fluid/platform/mkldnn_helper.h" #endif @@ -26,55 +28,9 @@ namespace paddle { namespace operators { -using DDim = framework::DDim; -static DDim UDDim(const DDim& x_dim, int k) { - // get x_dim and return the ddim of U - auto x_vec = vectorize(x_dim); - x_vec[x_vec.size() - 1] = k; - return phi::make_ddim(x_vec); -} -static DDim VHDDim(const DDim& x_dim, int k) { - // get x_dim and return the ddim of U - auto x_vec = vectorize(x_dim); - x_vec[x_vec.size() - 2] = k; - return phi::make_ddim(x_vec); -} -static DDim SDDim(const DDim& x_dim, int k) { - // get x_dim and return the ddim of U - auto x_vec = vectorize(x_dim); - x_vec[x_vec.size() - 2] = k; - x_vec.erase(x_vec.end() - 1); // rank - 1 - return phi::make_ddim(x_vec); -} - class SvdOp : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; - - void InferShape(framework::InferShapeContext* ctx) const override { - OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "svd"); - OP_INOUT_CHECK(ctx->HasOutput("U"), "Output", "U", "svd"); - OP_INOUT_CHECK(ctx->HasOutput("VH"), "Output", "VH", "svd"); - OP_INOUT_CHECK(ctx->HasOutput("S"), "Output", "S", "svd"); - - auto in_dims = ctx->GetInputDim("X"); - int x_rank = in_dims.size(); - PADDLE_ENFORCE_GE(in_dims.size(), - 2, - platform::errors::InvalidArgument( - "the rank of input must greater than 2")); - int m = in_dims[x_rank - 2]; - int n = in_dims[x_rank - 1]; - int k = std::min(m, n); - const bool full_uv = ctx->Attrs().Get("full_matrices"); - ctx->SetOutputDim("U", !full_uv ? UDDim(in_dims, k) : UDDim(in_dims, m)); - ctx->SetOutputDim("VH", !full_uv ? VHDDim(in_dims, k) : VHDDim(in_dims, n)); - ctx->SetOutputDim("S", SDDim(in_dims, k)); - - ctx->ShareLoD("X", /*->*/ "U"); - ctx->ShareLoD("X", /*->*/ "VH"); - ctx->ShareLoD("X", /*->*/ "S"); - } }; class SvdOpMaker : public framework::OpProtoAndCheckerMaker { @@ -159,10 +115,15 @@ class SvdGradMaker : public framework::SingleGradOpMaker { namespace ops = paddle::operators; +DECLARE_INFER_SHAPE_FUNCTOR(svd, + SvdInferShapeFunctor, + PD_INFER_META(phi::SvdInferMeta)); + REGISTER_OPERATOR(svd, ops::SvdOp, ops::SvdOpMaker, ops::SvdGradMaker, - ops::SvdGradMaker); + ops::SvdGradMaker, + SvdInferShapeFunctor); REGISTER_OPERATOR(svd_grad, ops::SvdGradOp); diff --git a/paddle/phi/api/yaml/legacy_api.yaml b/paddle/phi/api/yaml/legacy_api.yaml index a7d8f5b33889e593e0360ad523acd99ef7a2bae9..2a1b307c2b02bcc7e5bea9d6b1ee719613bfd896 100644 --- a/paddle/phi/api/yaml/legacy_api.yaml +++ b/paddle/phi/api/yaml/legacy_api.yaml @@ -2140,6 +2140,15 @@ data_type : x backward : sum_grad +- api : svd + args : (Tensor x, bool full_metrices) + output : Tensor(u), Tensor(s), Tensor(vh) + infer_meta : + func : SvdInferMeta + kernel : + func : svd + backward : svd_grad + # The python API paddle.nn.functional.swish has no `bete` argument, it may be removed later - api : swish args : (Tensor x, float beta=1.0) diff --git a/paddle/phi/api/yaml/legacy_backward.yaml b/paddle/phi/api/yaml/legacy_backward.yaml index 65952fc6806a3bd62d6b99643c33e2027d908e4e..8bbdc5766b4c1fac87bdbbaccb68d9e34a86939f 100644 --- a/paddle/phi/api/yaml/legacy_backward.yaml +++ b/paddle/phi/api/yaml/legacy_backward.yaml @@ -2133,6 +2133,17 @@ output : Tensor(grad_grad_x_grad) invoke : sum_grad(grad_grad_x, grad_grad_out_grad, dims, keep_dim, reduce_all, grad_grad_x_grad) +- backward_api : svd_grad + forward : svd (Tensor x, bool full) -> Tensor(u), Tensor(s), Tensor(vh) + args : (Tensor x, Tensor u, Tensor vh, Tensor s, Tensor u_grad, Tensor vh_grad, Tensor s_grad, bool full) + output : Tensor(x_grad) + infer_meta : + func : UnchangedInferMeta + param : [x] + kernel : + func : svd_grad + optional: u_grad, vh_grad, s_grad + - backward_api : swish_grad forward : swish (Tensor x, float beta=1.0) -> Tensor(out) args : (Tensor x, Tensor out_grad, float bete=1.0) diff --git a/paddle/phi/infermeta/unary.cc b/paddle/phi/infermeta/unary.cc index 307cd0ad296e871a85ce48169cefaf3194d69381..80f5da682957db7215b3f116f5e75cdaecc833ba 100644 --- a/paddle/phi/infermeta/unary.cc +++ b/paddle/phi/infermeta/unary.cc @@ -2715,6 +2715,53 @@ void SumRawInferMeta(const MetaTensor& x, out->set_layout(x.layout()); } +void SvdInferMeta(const MetaTensor& x, + bool full_matrices, + MetaTensor* u, + MetaTensor* s, + MetaTensor* vh) { + auto UDDim = [](const DDim& x_dim, int k) { + // get x_dim and return the ddim of U + auto x_vec = vectorize(x_dim); + x_vec[x_vec.size() - 1] = k; + return phi::make_ddim(x_vec); + }; + + auto VHDDim = [](const DDim& x_dim, int k) { + // get x_dim and return the ddim of U + auto x_vec = vectorize(x_dim); + x_vec[x_vec.size() - 2] = k; + return phi::make_ddim(x_vec); + }; + + auto SDDim = [](const DDim& x_dim, int k) { + // get x_dim and return the ddim of U + auto x_vec = vectorize(x_dim); + x_vec[x_vec.size() - 2] = k; + x_vec.erase(x_vec.end() - 1); // rank - 1 + return phi::make_ddim(x_vec); + }; + + auto in_dims = x.dims(); + int x_rank = in_dims.size(); + PADDLE_ENFORCE_GE( + in_dims.size(), + 2, + phi::errors::InvalidArgument("the rank of input must greater than 2")); + int m = in_dims[x_rank - 2]; + int n = in_dims[x_rank - 1]; + int k = std::min(m, n); + u->set_dims(!full_matrices ? UDDim(in_dims, k) : UDDim(in_dims, m)); + vh->set_dims(!full_matrices ? VHDDim(in_dims, k) : VHDDim(in_dims, n)); + s->set_dims(SDDim(in_dims, k)); + u->share_lod(x); + vh->share_lod(x); + s->share_lod(x); + u->set_dtype(x.dtype()); + vh->set_dtype(x.dtype()); + s->set_dtype(x.dtype()); +} + void TemporalShiftInferMeta(const MetaTensor& x, int seg_num, float shift_ratio, diff --git a/paddle/phi/infermeta/unary.h b/paddle/phi/infermeta/unary.h index 03f7b09fc80fc98711f5cd8504c234ca09e60b25..a9cb1c0b610b97e0bcae552f18538e2b28466c01 100644 --- a/paddle/phi/infermeta/unary.h +++ b/paddle/phi/infermeta/unary.h @@ -387,6 +387,12 @@ void SumRawInferMeta(const MetaTensor& x, DataType dtype, MetaTensor* out); +void SvdInferMeta(const MetaTensor& x, + bool full_matrices, + MetaTensor* u, + MetaTensor* s, + MetaTensor* vh); + void TemporalShiftInferMeta(const MetaTensor& x, int seg_num, float shift_ratio, diff --git a/paddle/phi/kernels/impl/svd_grad_kernel_impl.h b/paddle/phi/kernels/impl/svd_grad_kernel_impl.h index f87a8910ebe3e8f52f2ec332968e73112c74714a..ee7cab217893b79c317e5b91aa936fdd958fefd3 100644 --- a/paddle/phi/kernels/impl/svd_grad_kernel_impl.h +++ b/paddle/phi/kernels/impl/svd_grad_kernel_impl.h @@ -71,9 +71,9 @@ void SvdGradKernel(const Context& dev_ctx, const DenseTensor& u, const DenseTensor& vh, const DenseTensor& s, - const DenseTensor& u_grad, - const DenseTensor& vh_grad, - const DenseTensor& s_grad, + const paddle::optional& u_grad, + const paddle::optional& vh_grad, + const paddle::optional& s_grad, bool full_matrices, DenseTensor* x_grad) { const auto& dX = *x_grad; @@ -87,15 +87,33 @@ void SvdGradKernel(const Context& dev_ctx, dev_ctx, u, {u.dims().size() - 1}, {0}, {k}, {1}, {}); VH = SliceKernel( dev_ctx, vh, {vh.dims().size() - 2}, {0}, {k}, {1}, {}); - dU = SliceKernel( - dev_ctx, u_grad, {u_grad.dims().size() - 1}, {0}, {k}, {1}, {}); - dVH = SliceKernel( - dev_ctx, vh_grad, {vh.dims().size() - 2}, {0}, {k}, {1}, {}); + if (u_grad.get_ptr() != nullptr) { + dU = SliceKernel(dev_ctx, + *(u_grad.get_ptr()), + {u.dims().size() - 1}, + {0}, + {k}, + {1}, + {}); + } + if (vh_grad.get_ptr() != nullptr) { + dVH = SliceKernel(dev_ctx, + *(vh_grad.get_ptr()), + {vh.dims().size() - 2}, + {0}, + {k}, + {1}, + {}); + } } else { U = u; VH = vh; - dU = u_grad; - dVH = vh_grad; + if (u_grad.get_ptr() != nullptr) { + dU = *(u_grad.get_ptr()); + } + if (vh_grad.get_ptr() != nullptr) { + dVH = *(vh_grad.get_ptr()); + } } auto s_inverse = Pow(dev_ctx, s, -1); auto s_square = Pow(dev_ctx, s, 2); @@ -106,19 +124,17 @@ void SvdGradKernel(const Context& dev_ctx, F, Diag(dev_ctx, Infinits(dev_ctx, {k}), 0, 0)); F = Pow(dev_ctx, F, -1); - DenseTensor sigma_term; - DenseTensor u_term; - DenseTensor v_term; + DenseTensor sigma_term = Fill(dev_ctx, {1}, 0.0); + DenseTensor u_term = Fill(dev_ctx, {1}, 0.0); + DenseTensor v_term = Fill(dev_ctx, {1}, 0.0); - // if (ctx.HasInput(framework::GradVarName("S"))) - { - const DenseTensor& gS = s_grad; + if (s_grad.get_ptr() != nullptr) { + const DenseTensor& gS = *(s_grad.get_ptr()); sigma_term = Multiply(dev_ctx, Unsqueeze(gS, -2), U); sigma_term = Matmul(dev_ctx, sigma_term, VH); } - // if (ctx.HasInput(framework::GradVarName("U"))) { - { + if (u_grad.get_ptr() != nullptr) { auto UTG = Matmul(dev_ctx, U, dU, true, false); auto GTU = Matmul(dev_ctx, dU, U, true, false); u_term = Multiply( @@ -141,10 +157,7 @@ void SvdGradKernel(const Context& dev_ctx, } u_term = Matmul(dev_ctx, u_term, VH); } - // } - - // if (ctx.HasInput(framework::GradVarName("VH"))) { - { + if (vh_grad.get_ptr() != nullptr) { auto UTG = Matmul(dev_ctx, VH, dVH, false, true); auto GTU = Matmul(dev_ctx, dVH, VH, false, true); v_term = Multiply( diff --git a/paddle/phi/kernels/svd_grad_kernel.h b/paddle/phi/kernels/svd_grad_kernel.h index 474fd6ff03ddf6759ee003109d45b79d429ef17d..66331a71912859627c832aacc3152bd243904197 100644 --- a/paddle/phi/kernels/svd_grad_kernel.h +++ b/paddle/phi/kernels/svd_grad_kernel.h @@ -20,13 +20,13 @@ namespace phi { template void SvdGradKernel(const Context& dev_ctx, - const DenseTensor& X, - const DenseTensor& U, - const DenseTensor& VH, - const DenseTensor& S, - const DenseTensor& U_grad, - const DenseTensor& VH_grad, - const DenseTensor& S_grad, + const DenseTensor& x, + const DenseTensor& u, + const DenseTensor& vh, + const DenseTensor& s, + const paddle::optional& u_grad, + const paddle::optional& vh_grad, + const paddle::optional& s_grad, bool full_matrices, DenseTensor* X_grad); } // namespace phi diff --git a/python/paddle/fluid/tests/unittests/test_svd_op.py b/python/paddle/fluid/tests/unittests/test_svd_op.py index ef9bbae6b81ddfd307c1ba4a8c33c0c177b5cc45..b3cd48b05c0b2d3ac05c5eca7a4594ae9488c718 100644 --- a/python/paddle/fluid/tests/unittests/test_svd_op.py +++ b/python/paddle/fluid/tests/unittests/test_svd_op.py @@ -29,6 +29,7 @@ class TestSvdOp(OpTest): def setUp(self): paddle.enable_static() + self.python_api = paddle.linalg.svd self.generate_input() self.generate_output() self.op_type = "svd" @@ -55,7 +56,7 @@ class TestSvdOp(OpTest): self._output_data = np.linalg.svd(self._input_data) def test_check_output(self): - self.check_output(no_check_set=['U', 'VH']) + self.check_output(no_check_set=['U', 'VH'], check_eager=True) def test_svd_forward(self): """ u matmul diag(s) matmul vt must become X @@ -75,13 +76,19 @@ class TestSvdOp(OpTest): paddle.enable_static() def check_S_grad(self): - self.check_grad(['X'], ['S'], numeric_grad_delta=0.001) + self.check_grad(['X'], ['S'], + numeric_grad_delta=0.001, + check_eager=True) def check_U_grad(self): - self.check_grad(['X'], ['U'], numeric_grad_delta=0.001) + self.check_grad(['X'], ['U'], + numeric_grad_delta=0.001, + check_eager=True) def check_V_grad(self): - self.check_grad(['X'], ['VH'], numeric_grad_delta=0.001) + self.check_grad(['X'], ['VH'], + numeric_grad_delta=0.001, + check_eager=True) def test_check_grad(self): """ diff --git a/python/paddle/tensor/linalg.py b/python/paddle/tensor/linalg.py index 7e7f95d17a38f13ea42e516883620191d0e91c97..a77d6b5a2ad92a36286a690543baaddd5397ddde 100644 --- a/python/paddle/tensor/linalg.py +++ b/python/paddle/tensor/linalg.py @@ -1857,8 +1857,9 @@ def svd(x, full_matrices=False, name=None): # U * UH == I # V * VH == I """ - - if paddle.in_dynamic_mode(): + if in_dygraph_mode(): + return _C_ops.final_state_svd(x, full_matrices) + if _in_legacy_dygraph(): return _C_ops.svd(x, 'full_matrices', full_matrices) check_variable_and_dtype(x, 'dtype', ['float32', 'float64'], 'svd') check_type(full_matrices, 'full_matrices', bool, 'svd')