diff --git a/paddle/fluid/operators/determinant_op.cc b/paddle/fluid/operators/determinant_op.cc
index b4724eb3c83a39d182ea72e522aba65386b2d6b6..89d5d2ded15f97c8cb18cec9fa525bef20692c96 100644
--- a/paddle/fluid/operators/determinant_op.cc
+++ b/paddle/fluid/operators/determinant_op.cc
@@ -12,9 +12,10 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
-#include "paddle/fluid/operators/determinant_op.h"
-
 #include "paddle/fluid/framework/infershape_utils.h"
+#include "paddle/fluid/framework/op_proto_maker.h"
+#include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/framework/operator.h"
 #include "paddle/phi/core/infermeta_utils.h"
 #include "paddle/phi/infermeta/backward.h"
 #include "paddle/phi/infermeta/unary.h"
@@ -170,19 +171,19 @@
 REGISTER_OPERATOR(determinant_grad,
                   ops::DeterminantGradOp,
                   DeterminantGradInferShapeFunctor);
 
+DECLARE_INFER_SHAPE_FUNCTOR(slogdeterminant,
+                            SlogDeterminantInferShapeFunctor,
+                            PD_INFER_META(phi::UnchangedInferMeta));
 REGISTER_OPERATOR(slogdeterminant,
                   ops::SlogDeterminantOp,
                   ops::SlogDeterminantOpMaker,
                   ops::SlogDeterminantGradOpMaker<paddle::framework::OpDesc>,
-                  ops::SlogDeterminantGradOpMaker<paddle::imperative::OpBase>);
+                  ops::SlogDeterminantGradOpMaker<paddle::imperative::OpBase>,
+                  SlogDeterminantInferShapeFunctor);
 
+DECLARE_INFER_SHAPE_FUNCTOR(slogdeterminant_grad,
+                            SlogDeterminantGradInferShapeFunctor,
+                            PD_INFER_META(phi::GeneralUnaryGradInferMeta));
 REGISTER_OPERATOR(slogdeterminant_grad,
-                  ops::SlogDeterminantGradOp)  // reuse det grad op
-
-REGISTER_OP_CPU_KERNEL(slogdeterminant,
-                       ops::SlogDeterminantKernel<phi::CPUContext, float>,
-                       ops::SlogDeterminantKernel<phi::CPUContext, double>);
-
-REGISTER_OP_CPU_KERNEL(slogdeterminant_grad,
-                       ops::SlogDeterminantGradKernel<phi::CPUContext, float>,
-                       ops::SlogDeterminantGradKernel<phi::CPUContext, double>);
+                  ops::SlogDeterminantGradOp,
+                  SlogDeterminantGradInferShapeFunctor)  // reuse det grad op
diff --git a/paddle/fluid/operators/determinant_op.cu b/paddle/fluid/operators/determinant_op.cu
deleted file mode 100644
index d39d65f71e18fa55902ea7540f5a1b82f4122e6b..0000000000000000000000000000000000000000
--- a/paddle/fluid/operators/determinant_op.cu
+++ /dev/null
@@ -1,29 +0,0 @@
-/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License. */
-
-#include "paddle/fluid/operators/determinant_op.h"
-#include "paddle/fluid/framework/op_registry.h"
-
-namespace ops = paddle::operators;
-namespace plat = paddle::platform;
-
-REGISTER_OP_CUDA_KERNEL(
-    slogdeterminant,
-    ops::SlogDeterminantKernel<plat::CUDADeviceContext, float>,
-    ops::SlogDeterminantKernel<plat::CUDADeviceContext, double>);
-
-REGISTER_OP_CUDA_KERNEL(
-    slogdeterminant_grad,
-    ops::SlogDeterminantGradKernel<plat::CUDADeviceContext, float>,
-    ops::SlogDeterminantGradKernel<plat::CUDADeviceContext, double>);
diff --git a/paddle/fluid/operators/determinant_op.h b/paddle/fluid/operators/determinant_op.h
deleted file mode 100644
index 6a8160283f8008caf2ed73e1c4dcff066b64cb56..0000000000000000000000000000000000000000
--- a/paddle/fluid/operators/determinant_op.h
+++ /dev/null
@@ -1,227 +0,0 @@
-// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-//     http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#pragma once
-#include <algorithm>
-#include <cmath>
-#include <string>
-#include <unordered_map>
-#include <vector>
-
-#include "paddle/fluid/framework/op_registry.h"
-#include "paddle/fluid/platform/enforce.h"
-#include "paddle/fluid/platform/for_range.h"
-#include "paddle/phi/kernels/complex_kernel.h"
-#include "paddle/phi/kernels/elementwise_multiply_kernel.h"
-#include "paddle/phi/kernels/full_kernel.h"
-#include "paddle/phi/kernels/funcs/common_shape.h"
-#include "paddle/phi/kernels/funcs/diag_functor.h"
-#include "paddle/phi/kernels/funcs/math_function.h"
-#include "paddle/phi/kernels/funcs/matrix_inverse.h"
-#include "paddle/phi/kernels/funcs/unsqueeze.h"
-#include "paddle/phi/kernels/impl/determinant_grad_kernel_impl.h"
-#include "paddle/phi/kernels/impl/determinant_kernel_impl.h"
-#include "paddle/phi/kernels/matmul_kernel.h"
-#include "paddle/phi/kernels/transpose_kernel.h"
-
-namespace paddle {
-namespace operators {
-
-using Tensor = framework::Tensor;
-template <typename T>
-T sign(T val) {
-  return static_cast<T>(T(0) < val) - (val < T(0));
-}
-
-template <typename T>
-struct SlogDeterminantFunctor {
-  void operator()(const Tensor& input,
-                  const framework::ExecutionContext ctx,
-                  int64_t rank,
-                  int64_t batch_count,
-                  Tensor* output) {
-    std::vector<T> input_vec;
-    std::vector<T> sign_vec;
-    std::vector<T> log_vec;
-    std::vector<T> output_vec;
-    framework::TensorToVector(input, ctx.device_context(), &input_vec);
-    for (int64_t i = 0; i < batch_count; ++i) {  // maybe can be parallel
-      auto begin_iter = input_vec.begin() + i * rank * rank;
-      auto end_iter = input_vec.begin() + (i + 1) * rank * rank;
-      std::vector<T> sub_vec(begin_iter,
-                             end_iter);  // get every square matrix data
-      typename phi::detail::EigenMatrix<T>::MatrixType matrix(rank, rank);
-      for (int64_t i = 0; i < rank; ++i) {
-        for (int64_t j = 0; j < rank; ++j) {
-          matrix(i, j) = sub_vec[rank * i + j];
-        }
-      }
-      VLOG(2) << "det value: " << matrix.determinant();
-      VLOG(2) << "matrix val: " << matrix;
-      auto det_val = matrix.determinant();
-      sign_vec.push_back(sign(det_val));
-      det_val >= 0
-          ? log_vec.push_back(std::log(det_val))
-          : log_vec.push_back(std::log(std::abs(
-                det_val)));  // for computing log value of a negative value.
-    }
-    // merge sign_vec and log_vec as final output_vec
-    output_vec.insert(output_vec.end(), sign_vec.begin(), sign_vec.end());
-    output_vec.insert(output_vec.end(), log_vec.begin(), log_vec.end());
-    framework::TensorFromVector(output_vec, output);
-  }
-};
-
-template <typename DeviceContext, typename T>
-class SlogDeterminantKernel : public framework::OpKernel<T> {
- public:
-  void Compute(const framework::ExecutionContext& context) const override {
-    auto* input = context.Input<framework::Tensor>("Input");
-    auto input_dim = vectorize(input->dims());
-    auto input_dim_size = input_dim.size();
-    auto* output = context.Output<framework::Tensor>("Out");
-
-    auto batch_count = phi::detail::GetBatchCount(input->dims());
-    VLOG(2) << "input dim:" << input->dims();
-    PADDLE_ENFORCE_GE(
-        input_dim_size,
-        2,
-        platform::errors::InvalidArgument(
-            "the input matrix dimension size should greater than 2."));
-    PADDLE_ENFORCE_EQ(input_dim[input_dim_size - 1],
-                      input_dim[input_dim_size - 2],
-                      platform::errors::InvalidArgument(
-                          "the input matrix should be square matrix."));
-    auto rank = input_dim[input_dim_size - 1];  // square matrix length
-    SlogDeterminantFunctor<T>()(*input, context, rank, batch_count, output);
-    std::vector<int64_t> output_dim_vec(input_dim.begin(), input_dim.end() - 2);
-    if (input_dim.size() == static_cast<size_t>(2)) {
-      // when input is a two-dimension matrix, The det value is a number.
-      output_dim_vec = {1};
-    }
-    output_dim_vec.insert(output_dim_vec.begin(),
-                          2);  // make the output dims as same as numpy
-    auto output_dims = phi::make_ddim(output_dim_vec);
-    output->Resize(output_dims);
-    VLOG(2) << "output dim:" << output->dims();
-  }
-};
-
-template <typename DeviceContext, typename T>
-class SlogDeterminantGradKernel : public framework::OpKernel<T> {
- public:
-  void Compute(const framework::ExecutionContext& context) const override {
-    auto& orig_dev_ctx = context.template device_context<DeviceContext>();
-    const auto* input = context.Input<framework::Tensor>("Input");
-    const auto* slogdet = context.Input<framework::Tensor>("Out");
-    const auto* grad =
-        context.Input<framework::Tensor>(framework::GradVarName("Out"));
-    auto* dslogdet =
-        context.Output<framework::Tensor>(framework::GradVarName("Input"));
-
-    PADDLE_ENFORCE_EQ(grad->dims()[0],
-                      2,
-                      platform::errors::InvalidArgument(
-                          "The grad tensor of SlogDet should contain two"
-                          " grad: sign and absslogdet, but here %ld.",
-                          grad->dims()[0]));
-    if (input->dims().size() > 2) {
-      PADDLE_ENFORCE_EQ(
-          grad->dims().size() + 1,
-          input->dims().size(),
-          platform::errors::InvalidArgument(
-              "The grad tensor of slogdet dims size should 1 less than"
-              " input tensor's, but here differ %d",
-              input->dims().size() - grad->dims().size()));
-    }
-
-    auto& dev_ctx = static_cast<
-        const typename framework::ConvertToPhiContext<DeviceContext>::TYPE&>(
-        orig_dev_ctx);
-
-    // Check Whether the matrix is invertible
-    // (matrix A not invertible) == (absslogdet(A)=0)
-    auto slogdet_vec = slogdet->Split(1, 0);
-    auto absslogdet_val = slogdet_vec[0];
-    if (!phi::detail::CheckMatrixInvertible<
-            T,
-            typename framework::ConvertToPhiContext<DeviceContext>::TYPE>(
-            dev_ctx, &absslogdet_val)) {
-      // The matrix is not invertible
-      VLOG(3) << "The input matrix not invertible!";
-      dslogdet->Resize(input->dims());
-      phi::Full<T>(dev_ctx,
-                   phi::vectorize(input->dims()),
-                   std::numeric_limits<T>::quiet_NaN(),
-                   dslogdet);
-      return;
-    }
-
-    // The matrix is invertible
-    // let sl|A| = SlogDeterminant(A)
-    // Ref to https://people.maths.ox.ac.uk/gilesm/files/NA-08-01.pdf
-    // we set dsl|A| = unsqueeze(dslA, [-1, -2]) *
-    // inverse(A).conj().transpose(-2, -1)
-
-    // First: inverse(A)
-    framework::Tensor inverse_A;
-    // A must be square matrices!
-    inverse_A.Resize(input->dims());
-    inverse_A.mutable_data<T>(context.GetPlace());
-
-    phi::funcs::MatrixInverseFunctor<DeviceContext, T> mat_inv;
-    mat_inv(orig_dev_ctx, *input, &inverse_A);
-
-    VLOG(3) << "inverse(A) dims: " << inverse_A.dims();
-
-    // Second: inverse(A).conj()
-    auto conj_inverse_A = phi::Conj<T>(dev_ctx, inverse_A);
-
-    VLOG(3) << "inverse(A).conj() dims: " << conj_inverse_A.dims();
-
-    // Third: inverse(A).conj().transpose(-2, -1)
-    framework::Tensor transpose_inverse_A =
-        phi::TransposeLast2Dim<T>(dev_ctx, conj_inverse_A);
-    VLOG(3) << "inverse(A).conj().transpose(-2, -1) dims: "
-            << transpose_inverse_A.dims();
-
-    // Fourth: split grad value to [sign_grad, absslogdet_grad]
-    auto grad_vec = grad->Split(1, 0);
-    auto det_grad = grad_vec[1];
-
-    // remmove useless first dimension
-    int det_grad_size = det_grad.dims().size();
-    std::vector<int> det_grad_vec;
-    for (int i = 1; i < det_grad_size; ++i) {
-      det_grad_vec.emplace_back(det_grad.dims()[i]);
-    }
-    det_grad.Resize(det_grad.dims().reshape(det_grad_vec));
-
-    // Fifth: unsqueeze(dslA, [-1, -2])
-    auto unsqueeze1 = phi::funcs::Unsqueeze(det_grad, -1);
-    auto unsqueeze2 = phi::funcs::Unsqueeze(unsqueeze1, -2);
-    VLOG(3) << "unsqueezed(dslA, [-1, -2]) dims: " << unsqueeze2.dims();
-
-    // Finally: unsqueeze(dslA) * inverse(A)
-    auto res = phi::Multiply<T>(dev_ctx, unsqueeze2, transpose_inverse_A);
-    VLOG(3) << "unsqueeze(dslA) * inverse(A) dims: " << res.dims();
-
-    framework::TensorCopy(res, context.GetPlace(), dslogdet);
-    dslogdet->Resize(input->dims());
-    VLOG(3) << "dsl|A| dims: " << dslogdet->dims();
-  }
-};
-
-}  // namespace operators
-}  // namespace paddle
diff --git a/paddle/phi/api/yaml/legacy_api.yaml b/paddle/phi/api/yaml/legacy_api.yaml
index 6408b908130250bd770e824478ea12f01a909f79..a7d8f5b33889e593e0360ad523acd99ef7a2bae9 100644
--- a/paddle/phi/api/yaml/legacy_api.yaml
+++ b/paddle/phi/api/yaml/legacy_api.yaml
@@ -2029,6 +2029,15 @@
     func : slice
   backward : slice_grad
 
+- api : slogdet
+  args : (Tensor x)
+  output : Tensor
+  infer_meta :
+    func : UnchangedInferMeta
+  kernel :
+    func : slogdeterminant
+  backward : slogdet_grad
+
 # soft_shrink
 - api : soft_shrink
   args : (Tensor x, float lambda)
diff --git a/paddle/phi/api/yaml/legacy_backward.yaml b/paddle/phi/api/yaml/legacy_backward.yaml
index ba8c78edd97ca6c8f06ed065c7f90978f1a290f2..65952fc6806a3bd62d6b99643c33e2027d908e4e 100644
--- a/paddle/phi/api/yaml/legacy_backward.yaml
+++ b/paddle/phi/api/yaml/legacy_backward.yaml
@@ -1946,6 +1946,16 @@
   backward : slice_double_grad
   no_need_buffer : input
 
+- backward_api : slogdet_grad
+  forward : slogdet (Tensor x) -> Tensor(out)
+  args : (Tensor x, Tensor out, Tensor out_grad)
+  output : Tensor(x_grad)
+  infer_meta :
+    func : UnchangedInferMeta
+    param : [x]
+  kernel :
+    func : slogdeterminant_grad
+
- backward_api : soft_shrink_grad
  forward : soft_shrink (Tensor x, float lambda) -> Tensor(out)
  args : (Tensor x, Tensor out_grad, float lambda)
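For reference, the `slogdet` API registered above returns the sign and the log of the absolute determinant packed into one tensor with a leading dimension of 2, matching the kernel's output layout. A usage sketch (illustrative, based on the behavior of the kernels below):

    import paddle

    x = paddle.randn([3, 3, 3], dtype='float64')  # batch of three 3x3 matrices
    out = paddle.linalg.slogdet(x)                # shape [2, 3]
    sign, logabsdet = out[0], out[1]
    # sign * exp(logabsdet) recovers det(x) while avoiding overflow/underflow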
diff --git a/paddle/phi/kernels/cpu/slogdeterminant_grad_kernel.cc b/paddle/phi/kernels/cpu/slogdeterminant_grad_kernel.cc
new file mode 100644
index 0000000000000000000000000000000000000000..0854895f0c1c66b1c6f948cdaeab1ea78d4ca90f
--- /dev/null
+++ b/paddle/phi/kernels/cpu/slogdeterminant_grad_kernel.cc
@@ -0,0 +1,25 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/phi/kernels/slogdeterminant_grad_kernel.h"
+
+#include "paddle/phi/core/kernel_registry.h"
+#include "paddle/phi/kernels/impl/slogdeterminant_grad_kernel_impl.h"
+
+PD_REGISTER_KERNEL(slogdeterminant_grad,
+                   CPU,
+                   ALL_LAYOUT,
+                   phi::SlogDeterminantGradKernel,
+                   float,
+                   double) {}
diff --git a/paddle/phi/kernels/cpu/slogdeterminant_kernel.cc b/paddle/phi/kernels/cpu/slogdeterminant_kernel.cc
new file mode 100644
index 0000000000000000000000000000000000000000..6bd9f0296c62cb001268b4347f323b9a7b0c0add
--- /dev/null
+++ b/paddle/phi/kernels/cpu/slogdeterminant_kernel.cc
@@ -0,0 +1,25 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/phi/kernels/slogdeterminant_kernel.h"
+
+#include "paddle/phi/core/kernel_registry.h"
+#include "paddle/phi/kernels/impl/slogdeterminant_kernel_impl.h"
+
+PD_REGISTER_KERNEL(slogdeterminant,
+                   CPU,
+                   ALL_LAYOUT,
+                   phi::SlogDeterminantKernel,
+                   float,
+                   double) {}
diff --git a/paddle/phi/kernels/gpu/slogdeterminant_grad_kernel.cu b/paddle/phi/kernels/gpu/slogdeterminant_grad_kernel.cu
new file mode 100644
index 0000000000000000000000000000000000000000..153a97fa7a50ec13b872de6fea0e3cfe7f2b84fc
--- /dev/null
+++ b/paddle/phi/kernels/gpu/slogdeterminant_grad_kernel.cu
@@ -0,0 +1,25 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/phi/kernels/slogdeterminant_grad_kernel.h"
+
+#include "paddle/phi/core/kernel_registry.h"
+#include "paddle/phi/kernels/impl/slogdeterminant_grad_kernel_impl.h"
+
+PD_REGISTER_KERNEL(slogdeterminant_grad,
+                   GPU,
+                   ALL_LAYOUT,
+                   phi::SlogDeterminantGradKernel,
+                   float,
+                   double) {}
diff --git a/paddle/phi/kernels/gpu/slogdeterminant_kernel.cu b/paddle/phi/kernels/gpu/slogdeterminant_kernel.cu
new file mode 100644
index 0000000000000000000000000000000000000000..e94dc117fb96e292d2e9f992507a8ac252d16b08
--- /dev/null
+++ b/paddle/phi/kernels/gpu/slogdeterminant_kernel.cu
@@ -0,0 +1,25 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/phi/kernels/slogdeterminant_kernel.h"
+
+#include "paddle/phi/core/kernel_registry.h"
+#include "paddle/phi/kernels/impl/slogdeterminant_kernel_impl.h"
+
+PD_REGISTER_KERNEL(slogdeterminant,
+                   GPU,
+                   ALL_LAYOUT,
+                   phi::SlogDeterminantKernel,
+                   float,
+                   double) {}
diff --git a/paddle/phi/kernels/impl/slogdeterminant_grad_kernel_impl.h b/paddle/phi/kernels/impl/slogdeterminant_grad_kernel_impl.h
new file mode 100644
index 0000000000000000000000000000000000000000..9f629ab4bd161002f0478dfef8eb007119deb83f
--- /dev/null
+++ b/paddle/phi/kernels/impl/slogdeterminant_grad_kernel_impl.h
@@ -0,0 +1,121 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include "paddle/phi/core/tensor_utils.h"
+#include "paddle/phi/kernels/complex_kernel.h"
+#include "paddle/phi/kernels/elementwise_multiply_kernel.h"
+#include "paddle/phi/kernels/full_kernel.h"
+#include "paddle/phi/kernels/funcs/math_function.h"
+#include "paddle/phi/kernels/funcs/matrix_inverse.h"
+#include "paddle/phi/kernels/funcs/unsqueeze.h"
+#include "paddle/phi/kernels/impl/determinant_grad_kernel_impl.h"
+#include "paddle/phi/kernels/slogdeterminant_grad_kernel.h"
+#include "paddle/phi/kernels/transpose_kernel.h"
+
+namespace phi {
+
+template <typename T, typename Context>
+void SlogDeterminantGradKernel(const Context& dev_ctx,
+                               const DenseTensor& x,
+                               const DenseTensor& out,
+                               const DenseTensor& out_grad,
+                               DenseTensor* x_grad) {
+  PADDLE_ENFORCE_EQ(
+      out_grad.dims()[0],
+      2,
+      errors::InvalidArgument("The grad tensor of SlogDet should contain two"
+                              " grads: sign and absslogdet, but here %ld.",
+                              out_grad.dims()[0]));
+  if (x.dims().size() > 2) {
+    PADDLE_ENFORCE_EQ(
+        out_grad.dims().size() + 1,
+        x.dims().size(),
+        errors::InvalidArgument(
+            "The grad tensor of slogdet should have one dimension less than"
+            " the input tensor, but here they differ by %d.",
+            x.dims().size() - out_grad.dims().size()));
+  }
+
+  // Check whether the matrix is invertible:
+  // (matrix A not invertible) == (absslogdet(A) == 0)
+  auto slogdet_vec = out.Split(1, 0);
+  auto absslogdet_val = slogdet_vec[0];
+  if (!detail::CheckMatrixInvertible<T, Context>(dev_ctx, &absslogdet_val)) {
+    // The matrix is not invertible
+    VLOG(3) << "The input matrix is not invertible!";
+    x_grad->Resize(x.dims());
+    phi::Full<T>(dev_ctx,
+                 phi::vectorize(x.dims()),
+                 std::numeric_limits<T>::quiet_NaN(),
+                 x_grad);
+    return;
+  }
+
+  // The matrix is invertible; let sl|A| = SlogDeterminant(A).
+  // Ref: https://people.maths.ox.ac.uk/gilesm/files/NA-08-01.pdf
+  // We set dsl|A| = unsqueeze(dslA, [-1, -2]) *
+  // inverse(A).conj().transpose(-2, -1)
+
+  // First: inverse(A)
+  DenseTensor inverse_A;
+  // A must consist of square matrices!
+  inverse_A.Resize(x.dims());
+  dev_ctx.template Alloc<T>(&inverse_A);
+
+  phi::funcs::MatrixInverseFunctor<Context, T> mat_inv;
+  mat_inv(dev_ctx, x, &inverse_A);
+
+  VLOG(3) << "inverse(A) dims: " << inverse_A.dims();
+
+  // Second: inverse(A).conj()
+  auto conj_inverse_A = phi::Conj<T>(dev_ctx, inverse_A);
+
+  VLOG(3) << "inverse(A).conj() dims: " << conj_inverse_A.dims();
+
+  // Third: inverse(A).conj().transpose(-2, -1)
+  DenseTensor transpose_inverse_A =
+      phi::TransposeLast2Dim<T>(dev_ctx, conj_inverse_A);
+  VLOG(3) << "inverse(A).conj().transpose(-2, -1) dims: "
+          << transpose_inverse_A.dims();
+
+  // Fourth: split the grad value into [sign_grad, absslogdet_grad]
+  auto grad_vec = out_grad.Split(1, 0);
+  auto det_grad = grad_vec[1];
+
+  // remove the useless first dimension
+  int det_grad_size = det_grad.dims().size();
+  std::vector<int> det_grad_vec;
+  for (int i = 1; i < det_grad_size; ++i) {
+    det_grad_vec.emplace_back(det_grad.dims()[i]);
+  }
+  det_grad.Resize(det_grad.dims().reshape(det_grad_vec));
+
+  // Fifth: unsqueeze(dslA, [-1, -2])
+  auto unsqueeze1 = phi::funcs::Unsqueeze(det_grad, -1);
+  auto unsqueeze2 = phi::funcs::Unsqueeze(unsqueeze1, -2);
+  VLOG(3) << "unsqueezed(dslA, [-1, -2]) dims: " << unsqueeze2.dims();
+
+  // Finally: unsqueeze(dslA) * inverse(A)
+  auto res = phi::Multiply<T>(dev_ctx, unsqueeze2, transpose_inverse_A);
+  VLOG(3) << "unsqueeze(dslA) * inverse(A) dims: " << res.dims();
+
+  phi::Copy(dev_ctx, res, dev_ctx.GetPlace(), false, x_grad);
+  x_grad->Resize(x.dims());
+  VLOG(3) << "dsl|A| dims: " << x_grad->dims();
+}
+
+}  // namespace phi
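The formula implemented above can be sanity-checked numerically: for a real invertible A, the gradient of log|det(A)| with respect to A is inverse(A) transposed (conjugated in the complex case). A NumPy finite-difference sketch (illustrative only; `slogdet_grad_ref` is a hypothetical helper, not part of this change):

    import numpy as np

    def slogdet_grad_ref(a):
        # analytic gradient of log|det(A)| w.r.t. A: inv(A).conj().T
        return np.linalg.inv(a).conj().T

    np.random.seed(0)
    a, eps = np.random.rand(4, 4), 1e-6
    num = np.zeros_like(a)
    for i in range(4):
        for j in range(4):
            d = np.zeros_like(a)
            d[i, j] = eps
            num[i, j] = (np.linalg.slogdet(a + d)[1] -
                         np.linalg.slogdet(a - d)[1]) / (2 * eps)
    assert np.allclose(num, slogdet_grad_ref(a), atol=1e-5)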
diff --git a/paddle/phi/kernels/impl/slogdeterminant_kernel_impl.h b/paddle/phi/kernels/impl/slogdeterminant_kernel_impl.h
new file mode 100644
index 0000000000000000000000000000000000000000..a6e5060385e5bdd0a35f4163da6e2cbe5d302eae
--- /dev/null
+++ b/paddle/phi/kernels/impl/slogdeterminant_kernel_impl.h
@@ -0,0 +1,104 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include <cmath>
+#include <string>
+#include <vector>
+
+#include "paddle/fluid/framework/tensor_util.h"
+#include "paddle/phi/core/enforce.h"
+#include "paddle/phi/kernels/impl/determinant_kernel_impl.h"
+#include "paddle/phi/kernels/slogdeterminant_kernel.h"
+
+namespace phi {
+
+template <typename T>
+T sign(T val) {
+  return static_cast<T>(T(0) < val) - (val < T(0));
+}
+
+template <typename T, typename Context>
+struct SlogDeterminantFunctor {
+  void operator()(const Context& dev_ctx,
+                  const DenseTensor& input,
+                  int64_t rank,
+                  int64_t batch_count,
+                  DenseTensor* output) {
+    std::vector<T> input_vec;
+    std::vector<T> sign_vec;
+    std::vector<T> log_vec;
+    std::vector<T> output_vec;
+    paddle::framework::TensorToVector(input, dev_ctx, &input_vec);
+    for (int64_t i = 0; i < batch_count; ++i) {  // could be parallelized
+      auto begin_iter = input_vec.begin() + i * rank * rank;
+      auto end_iter = input_vec.begin() + (i + 1) * rank * rank;
+      std::vector<T> sub_vec(begin_iter,
+                             end_iter);  // get every square matrix data
+      typename detail::EigenMatrix<T>::MatrixType matrix(rank, rank);
+      for (int64_t i = 0; i < rank; ++i) {
+        for (int64_t j = 0; j < rank; ++j) {
+          matrix(i, j) = sub_vec[rank * i + j];
+        }
+      }
+      VLOG(2) << "det value: " << matrix.determinant();
+      VLOG(2) << "matrix val: " << matrix;
+      auto det_val = matrix.determinant();
+      sign_vec.push_back(sign(det_val));
+      det_val >= 0
+          ? log_vec.push_back(std::log(det_val))
+          : log_vec.push_back(std::log(std::abs(
+                det_val)));  // take log of the absolute value when negative
+    }
+    // merge sign_vec and log_vec into the final output_vec
+    output_vec.insert(output_vec.end(), sign_vec.begin(), sign_vec.end());
+    output_vec.insert(output_vec.end(), log_vec.begin(), log_vec.end());
+    paddle::framework::TensorFromVector(output_vec, output);
+  }
+};
+
+template <typename T, typename Context>
+void SlogDeterminantKernel(const Context& dev_ctx,
+                           const DenseTensor& x,
+                           DenseTensor* out) {
+  auto input_dim = vectorize(x.dims());
+  auto input_dim_size = input_dim.size();
+
+  auto batch_count = detail::GetBatchCount(x.dims());
+  VLOG(2) << "input dim:" << x.dims();
+  PADDLE_ENFORCE_GE(
+      input_dim_size,
+      2,
+      errors::InvalidArgument(
+          "the input matrix dimension size should be greater than or equal"
+          " to 2."));
+  PADDLE_ENFORCE_EQ(
+      input_dim[input_dim_size - 1],
+      input_dim[input_dim_size - 2],
+      errors::InvalidArgument("the input matrix should be a square matrix."));
+  auto rank = input_dim[input_dim_size - 1];  // square matrix length
+  SlogDeterminantFunctor<T, Context>()(dev_ctx, x, rank, batch_count, out);
+  std::vector<int64_t> output_dim_vec(input_dim.begin(), input_dim.end() - 2);
+  if (input_dim.size() == static_cast<size_t>(2)) {
+    // when the input is a two-dimensional matrix, the det value is a scalar
+    output_dim_vec = {1};
+  }
+  output_dim_vec.insert(output_dim_vec.begin(),
+                        2);  // make the output dims the same as numpy's
+  auto output_dims = phi::make_ddim(output_dim_vec);
+  out->Resize(output_dims);
+  VLOG(2) << "output dim:" << out->dims();
+}
+
+}  // namespace phi
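The functor above emits all sign values followed by all log-abs values, and the kernel then prepends a dimension of size 2. A NumPy reference of that packing (a sketch for illustration; `slogdeterminant_ref` is a hypothetical helper):

    import numpy as np

    def slogdeterminant_ref(x):
        # x: (..., n, n) -> (2, *batch); (2, 1) for a single matrix,
        # mirroring the kernel's output layout
        batch = x.shape[:-2] if x.ndim > 2 else (1,)
        sign, logabs = np.linalg.slogdet(x.reshape((-1,) + x.shape[-2:]))
        return np.stack([sign.reshape(batch), logabs.reshape(batch)])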
diff --git a/paddle/phi/kernels/slogdeterminant_grad_kernel.h b/paddle/phi/kernels/slogdeterminant_grad_kernel.h
new file mode 100644
index 0000000000000000000000000000000000000000..23bc12afda469fef98a469602b522c80c6510c69
--- /dev/null
+++ b/paddle/phi/kernels/slogdeterminant_grad_kernel.h
@@ -0,0 +1,28 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include "paddle/phi/core/dense_tensor.h"
+
+namespace phi {
+
+template <typename T, typename Context>
+void SlogDeterminantGradKernel(const Context& dev_ctx,
+                               const DenseTensor& x,
+                               const DenseTensor& out,
+                               const DenseTensor& out_grad,
+                               DenseTensor* x_grad);
+
+}  // namespace phi
diff --git a/paddle/phi/kernels/slogdeterminant_kernel.h b/paddle/phi/kernels/slogdeterminant_kernel.h
new file mode 100644
index 0000000000000000000000000000000000000000..46413bd06e48b8fbf960f17defdf8a41fdb33df1
--- /dev/null
+++ b/paddle/phi/kernels/slogdeterminant_kernel.h
@@ -0,0 +1,26 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include "paddle/phi/core/dense_tensor.h"
+
+namespace phi {
+
+template <typename T, typename Context>
+void SlogDeterminantKernel(const Context& dev_ctx,
+                           const DenseTensor& x,
+                           DenseTensor* out);
+
+}  // namespace phi
diff --git a/paddle/phi/ops/compat/slogdeterminant_sig.cc b/paddle/phi/ops/compat/slogdeterminant_sig.cc
new file mode 100644
index 0000000000000000000000000000000000000000..e4eeca0515230af32dcba1b48257a898c17523dc
--- /dev/null
+++ b/paddle/phi/ops/compat/slogdeterminant_sig.cc
@@ -0,0 +1,28 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/phi/core/compat/op_utils.h"
+
+namespace phi {
+
+KernelSignature SlogDeterminantGradOpArgumentMapping(
+    const ArgumentMappingContext& ctx) {
+  return KernelSignature(
+      "slogdeterminant_grad", {"Input", "Out", "Out@GRAD"}, {}, {"Input@GRAD"});
+}
+
+}  // namespace phi
+
+PD_REGISTER_ARG_MAPPING_FN(slogdeterminant_grad,
+                           phi::SlogDeterminantGradOpArgumentMapping);
diff --git a/python/paddle/fluid/tests/unittests/test_determinant_op.py b/python/paddle/fluid/tests/unittests/test_determinant_op.py
index 7a799ad3776067aa246a01605710d9d3206bc5f5..8b36848521122945177bad78dcb6ddec8181b8e4 100644
--- a/python/paddle/fluid/tests/unittests/test_determinant_op.py
+++ b/python/paddle/fluid/tests/unittests/test_determinant_op.py
@@ -104,15 +104,18 @@
 
     def setUp(self):
         self.op_type = "slogdeterminant"
+        self.python_api = paddle.linalg.slogdet
         self.init_data()
         self.outputs = {'Out': self.target}
 
     def test_check_output(self):
-        self.check_output()
+        self.check_output(check_eager=True)
 
     def test_check_grad(self):
         # the slog det's grad value is always huge
-        self.check_grad(['Input'], ['Out'], max_relative_error=0.1)
+        self.check_grad(['Input'], ['Out'],
+                        max_relative_error=0.1,
+                        check_eager=True)
 
     def init_data(self):
         np.random.seed(0)
diff --git a/python/paddle/tensor/linalg.py b/python/paddle/tensor/linalg.py
index 35336228984273a22e28c9050c8517fc31d837d6..7e7f95d17a38f13ea42e516883620191d0e91c97 100644
--- a/python/paddle/tensor/linalg.py
+++ b/python/paddle/tensor/linalg.py
@@ -1781,7 +1781,10 @@
             # [-0.98610914, -0.43010661, -0.10872950]])
     """
-    if paddle.in_dynamic_mode():
+    if in_dygraph_mode():
+        return _C_ops.final_state_slogdet(x)
+
+    elif paddle.in_dynamic_mode():
         return _C_ops.slogdeterminant(x)
 
     check_dtype(x.dtype, 'Input', ['float32', 'float64'], 'slogdet')
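As an end-to-end check of the new dygraph path, the packed output can be compared row by row against NumPy (an illustrative snippet, not part of the committed tests):

    import numpy as np
    import paddle

    x_np = np.random.rand(3, 3, 3).astype('float64')
    out = paddle.linalg.slogdet(paddle.to_tensor(x_np)).numpy()
    sign_ref, log_ref = np.linalg.slogdet(x_np)
    np.testing.assert_allclose(out[0], sign_ref, rtol=1e-6)
    np.testing.assert_allclose(out[1], log_ref, rtol=1e-6)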