From b91e8eec405ea6440124149cf6d21d559c7fded1 Mon Sep 17 00:00:00 2001 From: jiangcheng Date: Fri, 24 Sep 2021 17:11:00 +0800 Subject: [PATCH] add gradient kernel of det op and slogdet op (#36013) * add gradient kernel of det op and slogdet op * fix CI APPROVAL problem --- paddle/fluid/operators/determinant_op.cc | 11 +- paddle/fluid/operators/determinant_op.cu | 36 --- paddle/fluid/operators/determinant_op.h | 262 ++++++++++++++++-- .../tests/unittests/test_determinant_op.py | 32 +-- 4 files changed, 266 insertions(+), 75 deletions(-) diff --git a/paddle/fluid/operators/determinant_op.cc b/paddle/fluid/operators/determinant_op.cc index 379a401cde6..98247fbc862 100644 --- a/paddle/fluid/operators/determinant_op.cc +++ b/paddle/fluid/operators/determinant_op.cc @@ -48,6 +48,8 @@ class DeterminantGradOp : public framework::OperatorWithKernel { OP_INOUT_CHECK(ctx->HasInput("Input"), "Input", "Input", "DeterminantGradOp"); OP_INOUT_CHECK(ctx->HasInput("Out"), "Input", "Out", "DeterminantGradOp"); + OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Out")), "Input", + framework::GradVarName("Out"), "DeterminantGradOp"); OP_INOUT_CHECK(ctx->HasOutput(framework::GradVarName("Input")), "Output", framework::GradVarName("Input"), "DeterminantGradOp"); @@ -117,7 +119,8 @@ class SlogDeterminantGradOp : public framework::OperatorWithKernel { "SlogDeterminantGradOp"); OP_INOUT_CHECK(ctx->HasInput("Out"), "Input", "Out", "SlogDeterminantGradOp"); - + OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Out")), "Input", + framework::GradVarName("Out"), "SlogDeterminantGradOp"); OP_INOUT_CHECK(ctx->HasOutput(framework::GradVarName("Input")), "Output", framework::GradVarName("Input"), "SlogDeterminantGradOp"); @@ -179,7 +182,7 @@ REGISTER_OPERATOR(slogdeterminant, ops::SlogDeterminantOp, ops::SlogDeterminantGradOpMaker); REGISTER_OPERATOR(slogdeterminant_grad, - ops::DeterminantGradOp) // reuse det grad op + ops::SlogDeterminantGradOp) // reuse det grad op REGISTER_OP_CPU_KERNEL( slogdeterminant, ops::SlogDeterminantKernel, @@ -187,5 +190,5 @@ REGISTER_OP_CPU_KERNEL( REGISTER_OP_CPU_KERNEL( slogdeterminant_grad, - ops::DeterminantGradKernel, - ops::DeterminantGradKernel); + ops::SlogDeterminantGradKernel, + ops::SlogDeterminantGradKernel); diff --git a/paddle/fluid/operators/determinant_op.cu b/paddle/fluid/operators/determinant_op.cu index f17d94d8052..d19d4c3d093 100644 --- a/paddle/fluid/operators/determinant_op.cu +++ b/paddle/fluid/operators/determinant_op.cu @@ -14,42 +14,6 @@ limitations under the License. */ #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/operators/determinant_op.h" -#include "paddle/fluid/platform/cuda_primitives.h" - -namespace paddle { -namespace operators { - -using platform::PADDLE_CUDA_NUM_THREADS; -using Tensor = framework::Tensor; - -template -__global__ void DeterminantGrad(const size_t numel, T* out) { - int tid = threadIdx.x + blockIdx.x * blockDim.x; - if (tid < numel) { - out[tid] = static_cast(1); - } -} - -template -class DeterminantGradCUDAKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& context) const override { - const auto* dout = context.Input(framework::GradVarName("Out")); - const T* dout_data = dout->data(); - auto dout_dim = vectorize(dout->dims()); - - auto* dx = context.Output(framework::GradVarName("Input")); - T* dx_data = dx->mutable_data(context.GetPlace()); - - int64_t numel = dx->numel(); - for (int64_t idx = 0; idx < numel; idx++) { - dx_data[idx] = static_cast(1); - } - } -}; - -} // namespace operators -} // namespace paddle namespace ops = paddle::operators; namespace plat = paddle::platform; diff --git a/paddle/fluid/operators/determinant_op.h b/paddle/fluid/operators/determinant_op.h index ead1262d9fe..4c17869fb5d 100644 --- a/paddle/fluid/operators/determinant_op.h +++ b/paddle/fluid/operators/determinant_op.h @@ -19,7 +19,11 @@ #include #include #include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/operators/math/complex_functors.h" +#include "paddle/fluid/operators/math/matrix_inverse.h" +#include "paddle/fluid/operators/svd_helper.h" #include "paddle/fluid/platform/enforce.h" +#include "paddle/fluid/platform/for_range.h" namespace paddle { namespace operators { @@ -48,11 +52,10 @@ class EigenMatrix { inline int64_t GetBatchCount(const framework::DDim dims) { int64_t batch_count = 1; auto dim_size = dims.size(); - PADDLE_ENFORCE_GT(dim_size, 2, - platform::errors::InvalidArgument( - "To get the number of batch square matrices, " - "the size of dimension should greater than 2.", - dim_size)); + PADDLE_ENFORCE_GE( + dim_size, 2, + platform::errors::InvalidArgument( + "the input matrix dimension size should greater than 2.")); // Cumulative multiplying each dimension until the last 2 to get the batch // count, @@ -77,7 +80,7 @@ struct DeterminantFunctor { auto end_iter = input_vec.begin() + (i + 1) * rank * rank; std::vector sub_vec(begin_iter, end_iter); // get every square matrix data - Eigen::MatrixXf matrix(rank, rank); + typename EigenMatrix::MatrixType matrix(rank, rank); for (int64_t i = 0; i < rank; ++i) { for (int64_t j = 0; j < rank; ++j) { matrix(i, j) = sub_vec[rank * i + j]; @@ -109,41 +112,169 @@ class DeterminantKernel : public framework::OpKernel { "the input matrix should be square matrix.")); auto rank = input_dim[input_dim_size - 1]; // square matrix length DeterminantFunctor()(*input, context, rank, batch_count, output); + auto output_dims = + framework::slice_ddim(input->dims(), 0, input_dim_size - 2); if (input_dim_size > 2) { - auto output_dims = - framework::slice_ddim(input->dims(), 0, input_dim_size - 2); output->Resize(output_dims); + } else { + // when input is a two-dimension matrix, The det value is a number. + output->Resize({1}); } VLOG(2) << "output dim:" << output->dims(); } }; +template +struct FoundZeroFunctor { + FoundZeroFunctor(const T* x, int64_t numel, bool* res) + : x_(x), numel_(numel), res_(res) {} + HOSTDEVICE void operator()(size_t idx) const { + if (*res_ || idx >= static_cast(numel_)) { + // founded zero number + return; + } + *res_ = (x_[idx] == static_cast(0)); + } + const T* x_; + int64_t numel_; + bool* res_; +}; + +template +inline bool CheckMatrixInvertible(const framework::ExecutionContext& ctx, + const framework::Tensor* det) { + auto& dev_ctx = ctx.template device_context(); + auto numel = det->numel(); + + framework::Tensor dev_tensor; + auto* data = dev_tensor.mutable_data({1}, ctx.GetPlace()); + + // set false + math::SetConstant zero; + zero(dev_ctx, &dev_tensor, false); + + // find whether zero + platform::ForRange for_range(dev_ctx, numel); + FoundZeroFunctor functor(det->data(), numel, data); + for_range(functor); + + // copy to host + dev_ctx.Wait(); + framework::Tensor cpu_tensor; + framework::TensorCopy(dev_tensor, platform::CPUPlace(), &cpu_tensor); + + // if founded zero, the matrix is not invertible + // else the matrix is invertible + auto* res = cpu_tensor.data(); + return !(*res); +} + template class DeterminantGradKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& context) const override { - PADDLE_THROW(platform::errors::Unimplemented( - "Not support DeterminantGrad at this time.")); + auto& dev_ctx = context.template device_context(); + const auto* input = context.Input("Input"); + const auto* det = context.Input("Out"); + const auto* grad = + context.Input(framework::GradVarName("Out")); + auto* ddet = + context.Output(framework::GradVarName("Input")); + + auto input_dims_size = input->dims().size(); + if (input_dims_size > 2) { + PADDLE_ENFORCE_EQ( + grad->dims().size() + 2, input_dims_size, + platform::errors::InvalidArgument( + "The grad tensor of det dims size should 2 less than" + " input tensor's, but here differ %d", + input_dims_size - grad->dims().size())); + } else if (input_dims_size == 2) { + // input dims size 2 and grad dims size 1 is possible + PADDLE_ENFORCE_EQ( + grad->dims().size(), 1, + platform::errors::InvalidArgument( + "The grad tensor of det dims size should 2 less than" + " input tensor's, but here differ %d", + input_dims_size - grad->dims().size())); + } else { + // checked in forward, pass + } + + // Check Whether the matrix is invertible + // (matrix A not invertible) == (det(A)=0) + if (!CheckMatrixInvertible(context, det)) { + // The matrix is not invertible + VLOG(3) << "The input matrix not invertible!"; + ddet->Resize(input->dims()); + ddet->mutable_data(context.GetPlace()); + math::SetConstant zero; + zero(dev_ctx, ddet, static_cast(0.0f)); + return; + } + + // The matrix is invertible + // let |A| = Determinant(A) + // Ref to https://people.maths.ox.ac.uk/gilesm/files/NA-08-01.pdf + // we set d|A| = unsqueeze(dA * |A|, [-1, -2]) * inverse(A).transpose(-2, + // -1) + + math::DeviceIndependenceTensorOperations helper(context); + + // First: inverse(A) + framework::Tensor inverse_A; + // A must be square matrices! + inverse_A.Resize(input->dims()); + inverse_A.mutable_data(context.GetPlace()); + + math::MatrixInverseFunctor mat_inv; + mat_inv(dev_ctx, *input, &inverse_A); + + VLOG(3) << "inverse(A) dims: " << inverse_A.dims(); + + // Second: inverse(A).transpose(-2, -1) + framework::Tensor transpose_inverse_A = helper.Transpose(inverse_A); + VLOG(3) << "(dA * |A|).transpose(-2, -1) dims: " + << transpose_inverse_A.dims(); + + // Third: dA * |A| + auto mul_dA_detA = helper.Mul(*grad, *det); + VLOG(3) << "dA * |A| dims: " << mul_dA_detA.dims(); + + // Fourth: unsqueeze(dA * |A|, [-1, -2]) + auto unsqueeze1 = helper.Unsqueeze(mul_dA_detA, -1); + auto unsqueeze2 = helper.Unsqueeze(unsqueeze1, -2); + VLOG(3) << "unsqueezed(dA * |A|) dims: " << unsqueeze2.dims(); + + // Finally: unsqueeze(dA * |A|) * inverse(A) + auto res = helper.Mul(unsqueeze2, transpose_inverse_A); + + VLOG(3) << "unsqueeze(dA * |A|) * inverse(A) dims: " << res.dims(); + + framework::TensorCopy(res, context.GetPlace(), ddet); + + ddet->Resize(input->dims()); + VLOG(3) << "d|A| dims: " << ddet->dims(); } }; template struct SlogDeterminantFunctor { void operator()(const Tensor& input, const framework::ExecutionContext ctx, - int rank, int batch_count, Tensor* output) { + int64_t rank, int64_t batch_count, Tensor* output) { std::vector input_vec; std::vector sign_vec; std::vector log_vec; std::vector output_vec; framework::TensorToVector(input, ctx.device_context(), &input_vec); - for (int i = 0; i < batch_count; ++i) { // maybe can be parallel + for (int64_t i = 0; i < batch_count; ++i) { // maybe can be parallel auto begin_iter = input_vec.begin() + i * rank * rank; auto end_iter = input_vec.begin() + (i + 1) * rank * rank; std::vector sub_vec(begin_iter, end_iter); // get every square matrix data typename EigenMatrix::MatrixType matrix(rank, rank); - for (int i = 0; i < rank; ++i) { - for (int j = 0; j < rank; ++j) { + for (int64_t i = 0; i < rank; ++i) { + for (int64_t j = 0; j < rank; ++j) { matrix(i, j) = sub_vec[rank * i + j]; } } @@ -185,6 +316,10 @@ class SlogDeterminantKernel : public framework::OpKernel { auto rank = input_dim[input_dim_size - 1]; // square matrix length SlogDeterminantFunctor()(*input, context, rank, batch_count, output); std::vector output_dim_vec(input_dim.begin(), input_dim.end() - 2); + if (input_dim.size() == static_cast(2)) { + // when input is a two-dimension matrix, The det value is a number. + output_dim_vec = {1}; + } output_dim_vec.insert(output_dim_vec.begin(), 2); // make the output dims as same as numpy auto output_dims = framework::make_ddim(output_dim_vec); @@ -197,8 +332,103 @@ template class SlogDeterminantGradKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& context) const override { - PADDLE_THROW(platform::errors::Unimplemented( - "Not support SlogDeterminantGrad at this time.")); + auto& dev_ctx = context.template device_context(); + const auto* input = context.Input("Input"); + const auto* slogdet = context.Input("Out"); + const auto* grad = + context.Input(framework::GradVarName("Out")); + auto* dslogdet = + context.Output(framework::GradVarName("Input")); + + PADDLE_ENFORCE_EQ(grad->dims()[0], 2, + platform::errors::InvalidArgument( + "The grad tensor of SlogDet should contain two" + " grad: sign and absslogdet, but here %ld.", + grad->dims()[0])); + if (input->dims().size() > 2) { + PADDLE_ENFORCE_EQ( + grad->dims().size() + 1, input->dims().size(), + platform::errors::InvalidArgument( + "The grad tensor of slogdet dims size should 1 less than" + " input tensor's, but here differ %d", + input->dims().size() - grad->dims().size())); + } + + // Check Whether the matrix is invertible + // (matrix A not invertible) == (absslogdet(A)=0) + auto slogdet_vec = slogdet->Split(1, 0); + auto absslogdet_val = slogdet_vec[0]; + if (!CheckMatrixInvertible(context, &absslogdet_val)) { + // The matrix is not invertible + VLOG(3) << "The input matrix not invertible!"; + dslogdet->Resize(input->dims()); + dslogdet->mutable_data(context.GetPlace()); + math::SetConstant zero; + zero(dev_ctx, dslogdet, std::numeric_limits::quiet_NaN()); + return; + } + + // The matrix is invertible + // let sl|A| = SlogDeterminant(A) + // Ref to https://people.maths.ox.ac.uk/gilesm/files/NA-08-01.pdf + // we set dsl|A| = unsqueeze(dslA, [-1, -2]) * + // inverse(A).conj().transpose(-2, -1) + + math::DeviceIndependenceTensorOperations helper(context); + + // First: inverse(A) + framework::Tensor inverse_A; + // A must be square matrices! + inverse_A.Resize(input->dims()); + inverse_A.mutable_data(context.GetPlace()); + + math::MatrixInverseFunctor mat_inv; + mat_inv(dev_ctx, *input, &inverse_A); + + VLOG(3) << "inverse(A) dims: " << inverse_A.dims(); + + // Second: inverse(A).conj() + framework::Tensor conj_inverse_A; + conj_inverse_A.Resize(inverse_A.dims()); + auto numel = input->numel(); + auto* conj_data = conj_inverse_A.mutable_data(context.GetPlace(), + size_t(numel * sizeof(T))); + + platform::ForRange for_range(dev_ctx, numel); + math::ConjFunctor functor(inverse_A.data(), numel, conj_data); + for_range(functor); + + VLOG(3) << "inverse(A).conj() dims: " << conj_inverse_A.dims(); + + // Third: inverse(A).conj().transpose(-2, -1) + framework::Tensor transpose_inverse_A = helper.Transpose(conj_inverse_A); + VLOG(3) << "inverse(A).conj().transpose(-2, -1) dims: " + << transpose_inverse_A.dims(); + + // Fourth: split grad value to [sign_grad, absslogdet_grad] + auto grad_vec = grad->Split(1, 0); + auto det_grad = grad_vec[1]; + + // remmove useless first dimension + int det_grad_size = det_grad.dims().size(); + std::vector det_grad_vec; + for (int i = 1; i < det_grad_size; ++i) { + det_grad_vec.emplace_back(det_grad.dims()[i]); + } + det_grad.Resize(det_grad.dims().reshape(det_grad_vec)); + + // Fifth: unsqueeze(dslA, [-1, -2]) + auto unsqueeze1 = helper.Unsqueeze(det_grad, -1); + auto unsqueeze2 = helper.Unsqueeze(unsqueeze1, -2); + VLOG(3) << "unsqueezed(dslA, [-1, -2]) dims: " << unsqueeze2.dims(); + + // Finally: unsqueeze(dslA) * inverse(A) + auto res = helper.Mul(unsqueeze2, transpose_inverse_A); + VLOG(3) << "unsqueeze(dslA) * inverse(A) dims: " << res.dims(); + + framework::TensorCopy(res, context.GetPlace(), dslogdet); + dslogdet->Resize(input->dims()); + VLOG(3) << "dsl|A| dims: " << dslogdet->dims(); } }; diff --git a/python/paddle/fluid/tests/unittests/test_determinant_op.py b/python/paddle/fluid/tests/unittests/test_determinant_op.py index c19d44eb030..f8110bffa2f 100644 --- a/python/paddle/fluid/tests/unittests/test_determinant_op.py +++ b/python/paddle/fluid/tests/unittests/test_determinant_op.py @@ -16,7 +16,7 @@ from __future__ import print_function import unittest import numpy as np -from op_test import OpTest, skip_check_grad_ci +from op_test import OpTest import paddle import paddle.nn.functional as F import paddle.fluid as fluid @@ -26,7 +26,6 @@ import paddle.tensor as tensor paddle.enable_static() -@skip_check_grad_ci(reason="determinant grad is in progress.") class TestDeterminantOp(OpTest): def setUp(self): self.init_data() @@ -37,11 +36,11 @@ class TestDeterminantOp(OpTest): self.check_output() def test_check_grad(self): - pass + self.check_grad(['Input'], ['Out']) def init_data(self): np.random.seed(0) - self.case = np.random.rand(3, 3, 3, 3, 3).astype('float64') + self.case = np.random.rand(3, 3, 3, 5, 5).astype('float64') self.inputs = {'Input': self.case} self.target = np.linalg.det(self.case) @@ -49,30 +48,25 @@ class TestDeterminantOp(OpTest): class TestDeterminantOpCase1(TestDeterminantOp): def init_data(self): np.random.seed(0) - self.case = np.random.rand(3, 3, 3, 3).astype(np.float32) + self.case = np.random.rand(10, 10).astype('float32') self.inputs = {'Input': self.case} self.target = np.linalg.det(self.case) - def test_check_grad(self): - pass - class TestDeterminantOpCase2(TestDeterminantOp): def init_data(self): np.random.seed(0) - self.case = np.random.rand(4, 2, 4, 4).astype('float64') + # not invertible matrix + self.case = np.ones([4, 2, 4, 4]).astype('float64') self.inputs = {'Input': self.case} self.target = np.linalg.det(self.case) - def test_check_grad(self): - pass - class TestDeterminantAPI(unittest.TestCase): def setUp(self): - self.shape = [3, 3, 3, 3] np.random.seed(0) - self.x = np.random.rand(3, 3, 3, 3).astype(np.float32) + self.shape = [3, 3, 5, 5] + self.x = np.random.random(self.shape).astype(np.float32) self.place = paddle.CPUPlace() def test_api_static(self): @@ -96,7 +90,6 @@ class TestDeterminantAPI(unittest.TestCase): paddle.enable_static() -@skip_check_grad_ci(reason="slogdeterminant grad is in progress.") class TestSlogDeterminantOp(OpTest): def setUp(self): self.op_type = "slogdeterminant" @@ -107,11 +100,12 @@ class TestSlogDeterminantOp(OpTest): self.check_output() def test_check_grad(self): - pass + # the slog det's grad value is always huge + self.check_grad(['Input'], ['Out'], max_relative_error=0.1) def init_data(self): np.random.seed(0) - self.case = np.random.rand(3, 3, 3, 3).astype('float64') + self.case = np.random.rand(4, 5, 5).astype('float64') self.inputs = {'Input': self.case} self.target = np.array(np.linalg.slogdet(self.case)) @@ -126,9 +120,9 @@ class TestSlogDeterminantOpCase1(TestSlogDeterminantOp): class TestSlogDeterminantAPI(unittest.TestCase): def setUp(self): - self.shape = [3, 3, 3, 3] np.random.seed(0) - self.x = np.random.rand(3, 3, 3, 3).astype(np.float32) + self.shape = [3, 3, 5, 5] + self.x = np.random.random(self.shape).astype(np.float32) self.place = paddle.CPUPlace() def test_api_static(self): -- GitLab