From 657b6742fc2436e920b80548bdf8a1fe20782241 Mon Sep 17 00:00:00 2001
From: Yulong Ao
Date: Mon, 10 Jan 2022 16:34:45 +0800
Subject: [PATCH] Add the backward support for QR (#38824)

* Add the backward support for QR

* Remove unnecessary comments
---
 paddle/fluid/operators/qr_op.h                  | 123 +++++++++++++++-
 paddle/fluid/operators/svd_helper.h             | 135 ++++++++++++++++++
 .../fluid/tests/unittests/CMakeLists.txt        |   1 +
 .../fluid/tests/unittests/test_qr_op.py         |  91 +++++++++++-
 4 files changed, 347 insertions(+), 3 deletions(-)

diff --git a/paddle/fluid/operators/qr_op.h b/paddle/fluid/operators/qr_op.h
index 73ba52f590..65dfb4261e 100644
--- a/paddle/fluid/operators/qr_op.h
+++ b/paddle/fluid/operators/qr_op.h
@@ -19,6 +19,7 @@
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/framework/operator.h"
 #include "paddle/fluid/operators/math/complex_functors.h"
+#include "paddle/fluid/operators/svd_helper.h"
 #include "paddle/fluid/platform/for_range.h"
 
 namespace paddle {
@@ -79,9 +80,11 @@ class QrCPUKernel : public framework::OpKernel<T> {
       q_data = q.mutable_data<math::Real<T>>(
           context.GetPlace(),
           size_t(batch_size * m * k * sizeof(math::Real<T>)));
+      memset(q_data, 0, size_t(batch_size * m * k * sizeof(math::Real<T>)));
     }
     auto* r_data = r.mutable_data<math::Real<T>>(
         context.GetPlace(), size_t(batch_size * k * n * sizeof(math::Real<T>)));
+    memset(r_data, 0, size_t(batch_size * k * n * sizeof(math::Real<T>)));
 
     // Implement QR by calling Eigen
     for (int i = 0; i < batch_size; ++i) {
@@ -126,8 +129,124 @@ template <typename DeviceContext, typename T>
 class QrGradKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& ctx) const {
-    PADDLE_THROW(platform::errors::InvalidArgument(
-        "QR doesn't have the backward kernel now and will be supported soon."));
+    const framework::Tensor& Q = *ctx.Input<framework::Tensor>("Q");
+    const framework::Tensor& R = *ctx.Input<framework::Tensor>("R");
+    // Use a different name A instead of X
+    const framework::Tensor& A = *ctx.Input<framework::Tensor>("X");
+    const framework::Tensor& dQ =
+        *ctx.Input<framework::Tensor>(framework::GradVarName("Q"));
+    const framework::Tensor& dR =
+        *ctx.Input<framework::Tensor>(framework::GradVarName("R"));
+    // Use a different name dA instead of dX
+    framework::Tensor& dA =
+        *ctx.Output<framework::Tensor>(framework::GradVarName("X"));
+    dA.mutable_data<math::Real<T>>(ctx.GetPlace());
+    auto& dev_ctx = ctx.template device_context<DeviceContext>();
+    math::SetConstant<DeviceContext, T>()(dev_ctx, &dA, T(0));
+
+    auto dito =
+        math::DeviceIndependenceTensorOperations<DeviceContext, T>(ctx);
+
+    std::string mode = ctx.Attr<std::string>("mode");
+    bool compute_q, reduced;
+    std::tie(compute_q, reduced) = _parse_qr_mode(mode);
+    if (!compute_q) {
+      PADDLE_THROW(platform::errors::InvalidArgument(
+          "The derivative of qr is not implemented when mode='r'."));
+    }
+
+    auto a_dims = A.dims();
+    int a_rank = a_dims.size();
+    int m = a_dims[a_rank - 2];
+    int n = a_dims[a_rank - 1];
+
+    if ((m > n) && (!reduced)) {
+      PADDLE_THROW(platform::errors::InvalidArgument(
+          "The derivative of qr is not implemented when mode='complete' and "
+          "nrows > ncols."));
+    }
+
+    // m >= n case
+    auto m_gt_n_case = [](
+        const framework::ExecutionContext& ctx,
+        math::DeviceIndependenceTensorOperations<DeviceContext, T>& dito,
+        const Tensor& dQ, const Tensor& dR, const Tensor& A, const Tensor& Q,
+        const Tensor& R) -> framework::Tensor {
+      // Hai-Jun Liao, Jin-Guo Liu, Lei Wang, Tao Xiang (2019). Differentiable
+      // Programming Tensor Networks.
+      // https://arxiv.org/abs/1903.09650 Section 3. QR factorization
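+      //
+      // In outline (^H denotes the conjugate transpose):
+      //   M = R * dR^H - dQ^H * Q
+      //   copyltu(M) = tril(M, 0) + tril(M, -1)^H
+      //   dA = (dQ + Q * copyltu(M)) * R^-H
+      // The trailing R^-H factor is applied at the end through a triangular
+      // solve on R^H.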
+
+      // dR^H
+      framework::Tensor R_term;
+      if (ctx.HasInput(framework::GradVarName("R"))) {
+        R_term = dito.Matmul(R, dito.Transpose(dR));
+      } else {
+        R_term = dito.Fill(framework::vectorize<int>(R.dims()), 0);
+      }
+
+      // dQ^H * Q
+      framework::Tensor Q_term;
+      if (ctx.HasInput(framework::GradVarName("Q"))) {
+        Q_term = dito.Matmul(dito.Transpose(dQ), Q);
+      } else {
+        Q_term = dito.Fill(framework::vectorize<int>(R.dims()), 0);
+      }
+
+      framework::Tensor M_tmp1 = dito.Sub(R_term, Q_term);
+
+      // Compute M = tril(M, 0) + tril(M, -1)^H, i.e. copy the lower triangle
+      // of M onto its upper triangle
+      framework::Tensor M_tril_0 = dito.TrilTriu(M_tmp1, 0, true);
+      framework::Tensor M_tril_1 = dito.TrilTriu(M_tmp1, -1, true);
+      framework::Tensor M = dito.Add(M_tril_0, dito.Transpose(M_tril_1));
+
+      framework::Tensor rhs_term;
+      if (ctx.HasInput(framework::GradVarName("Q"))) {
+        rhs_term = dito.Add(dQ, dito.Matmul(Q, M));
+      } else {
+        rhs_term = dito.Matmul(Q, M);
+      }
+
+      // dA * R^H = rhs_term
+      auto dA =
+          dito.TriangularSolve(dito.Transpose(dito.Conj(dito.Transpose(R))),
+                               dito.Transpose(rhs_term),
+                               /*upper=*/true,
+                               /*transpose=*/false,
+                               /*unitriangular=*/false);
+
+      return dito.Transpose(dA);
+    };
+
+    if (m >= n) {
+      auto dA_tmp = m_gt_n_case(ctx, dito, dQ, dR, A, Q, R);
+      framework::TensorCopy(dA_tmp, dA.place(), &dA);
+    } else {
+      // If m < n for input matrices A, we partition A = [X|Y] and R = [U|V]
+      // Calculate dX and dY individually and concatenate them to get dA
+      dA.mutable_data<math::Real<T>>(ctx.GetPlace());
+
+      auto Y = dito.Slice(A, {-1}, {m}, {n});
+      auto U = dito.Slice(R, {-1}, {0}, {m});
+      framework::Tensor dY, dX, dV, dR_tmp, dQ_prime;
+
+      if (ctx.HasInput(framework::GradVarName("R"))) {
+        dV = dito.Slice(dR, {-1}, {m}, {n});
+        dR_tmp = dito.Slice(dR, {-1}, {0}, {m});
+        // Y * dV^H
+        dQ_prime = dito.Matmul(Y, dito.Transpose(dV));
+      } else {
+        dV = dito.Fill(framework::vectorize<int>(Y.dims()), 0);
+        dQ_prime = dito.Fill(framework::vectorize<int>(Q.dims()), 0);
+      }
+
+      if (ctx.HasInput(framework::GradVarName("Q"))) {
+        dQ_prime = dito.Add(dQ_prime, dQ);
+      }
+      dX = m_gt_n_case(ctx, dito, dQ_prime, dR_tmp, A, Q, U);
+      dY = dito.Matmul(Q, dV);
+      // Concatenate dX and dY to get dA.
+      auto dA_tmp = dito.ConcatTwoTensors(dX, dY, -1);
+      framework::TensorCopy(dA_tmp, dA.place(), &dA);
+    }
   }
 };
 
diff --git a/paddle/fluid/operators/svd_helper.h b/paddle/fluid/operators/svd_helper.h
index 6b25846822..8d17ddec6f 100644
--- a/paddle/fluid/operators/svd_helper.h
+++ b/paddle/fluid/operators/svd_helper.h
@@ -146,6 +146,93 @@ static std::vector<int> GetBroadcastShape(InTensors ins) {
   return broadcast_shape;
 }
 
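+// Shape inference for the concat used by ConcatTwoTensors below: all
+// non-concat dimensions must match and the extents along `axis` are summed.
+// For example, concatenating shapes [2, 3, 4] and [2, 5, 4] along axis 1
+// yields [2, 8, 4]; at compile time an unknown (-1) extent in any input
+// makes the output extent along `axis` unknown as well.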
+ "But received input[0]'s shape = " + "[%s], input[%d]'s shape = [%s].", + i, inputs_dims[0], i, inputs_dims[i])); + for (size_t j = 0; j < in_zero_dims_size; j++) { + if (j == axis) { + if (is_runtime) { + out_dims[axis] += inputs_dims[i][j]; + } else { + if (inputs_dims[i][j] == -1 || out_dims[j] == -1) { + out_dims[axis] = -1; + } else { + out_dims[axis] += inputs_dims[i][j]; + } + } + } else { + bool check_shape = + is_runtime || (inputs_dims[0][j] > 0 && inputs_dims[i][j] > 0); + if (check_shape) { + // check all shape in run time + PADDLE_ENFORCE_EQ(inputs_dims[0][j], inputs_dims[i][j], + platform::errors::InvalidArgument( + "The %d-th dimension of input[0] and input[%d] " + "is expected to be equal." + "But received input[0]'s shape = " + "[%s], input[%d]'s shape = [%s].", + j, i, inputs_dims[0], i, inputs_dims[i])); + } + if (!is_runtime && out_dims[j] == -1 && inputs_dims[i][j] > 0) { + out_dims[j] = inputs_dims[i][j]; + } + } + } + } + return out_dims; +} + +static inline int64_t ComputeAxisForConcatOp(int64_t axis, int64_t rank) { + PADDLE_ENFORCE_EQ( + axis >= -rank && axis < rank, true, + platform::errors::InvalidArgument( + "The axis is expected to be in range of [%d, %d), but got %d", -rank, + rank, axis)); + if (axis < 0) { + axis = axis + rank; + } + return axis > 0 ? axis : 0; +} + +// Prepared for the broadcast operation +static std::vector get_broadcast_batch_portion( + std::vector x, std::vector y) { + size_t size_x = x.size(); + size_t size_y = y.size(); + size_t size = std::max(size_x, size_y); + std::vector batchPortion(size); + + ptrdiff_t i = (ptrdiff_t)size - 1; + for (; i >= 0; --i) { + ptrdiff_t offset = size - i - 1; + ptrdiff_t dim_x = size_x - offset - 1; + ptrdiff_t dim_y = size_y - offset - 1; + int64_t x_size = (dim_x >= 0) ? x[dim_x] : 1; + int64_t y_size = (dim_y >= 0) ? y[dim_y] : 1; + + PADDLE_ENFORCE_EQ( + (x_size == y_size || x_size == 1 || y_size == 1), true, + platform::errors::PreconditionNotMet( + "The size of tensor x (%d) must match the size of tensor y " + "(%d) at non-singleton dimension %d.", + x_size, y_size, i)); + + batchPortion[i] = x_size != 1 ? 
+static std::vector<int64_t> get_broadcast_batch_portion(
+    std::vector<int64_t> x, std::vector<int64_t> y) {
+  size_t size_x = x.size();
+  size_t size_y = y.size();
+  size_t size = std::max(size_x, size_y);
+  std::vector<int64_t> batchPortion(size);
+
+  ptrdiff_t i = (ptrdiff_t)size - 1;
+  for (; i >= 0; --i) {
+    ptrdiff_t offset = size - i - 1;
+    ptrdiff_t dim_x = size_x - offset - 1;
+    ptrdiff_t dim_y = size_y - offset - 1;
+    int64_t x_size = (dim_x >= 0) ? x[dim_x] : 1;
+    int64_t y_size = (dim_y >= 0) ? y[dim_y] : 1;
+
+    PADDLE_ENFORCE_EQ(
+        (x_size == y_size || x_size == 1 || y_size == 1), true,
+        platform::errors::PreconditionNotMet(
+            "The size of tensor x (%d) must match the size of tensor y "
+            "(%d) at non-singleton dimension %d.",
+            x_size, y_size, i));
+
+    batchPortion[i] = x_size != 1 ? x_size : y_size;
+  }
+  return batchPortion;
+}
+
 #define DITO_TRANSPOSE_RANK_CASE(N)              \
   case N: {                                      \
     math::Transpose<DeviceContext, T, N> trans;  \
@@ -515,6 +602,54 @@ struct DeviceIndependenceTensorOperations {
     return CreateOpRunAndReturnTensor("tril_triu", inputs, attrs, out_shape);
   }
 
+  framework::Tensor TriangularSolve(const framework::Tensor& x,
+                                    const framework::Tensor& y, bool upper,
+                                    bool transpose, bool unitriangular) {
+    framework::AttributeMap attrs;
+    attrs["upper"] = upper;
+    attrs["transpose"] = transpose;
+    attrs["unitriangular"] = unitriangular;
+    NameInTensorMap inputs({{"X", {&x}}, {"Y", {&y}}});
+    auto x_dims = x.dims();
+    auto y_dims = y.dims();
+    auto y_dims_n = y_dims.size();
+    std::vector<int64_t> x_dims_vec = paddle::framework::vectorize(x_dims);
+    std::vector<int64_t> y_dims_vec = paddle::framework::vectorize(y_dims);
+    std::vector<int64_t> x_dims_vec_cut(x_dims_vec.begin(),
+                                        x_dims_vec.end() - 2);
+    std::vector<int64_t> y_dims_vec_cut(y_dims_vec.begin(),
+                                        y_dims_vec.end() - 2);
+    std::vector<int64_t> expand_batch_portion =
+        get_broadcast_batch_portion(x_dims_vec_cut, y_dims_vec_cut);
+    std::vector<int64_t> y_broadcast_dims({expand_batch_portion});
+    y_broadcast_dims.insert(y_broadcast_dims.end(), {y_dims_vec[y_dims_n - 2],
+                                                     y_dims_vec[y_dims_n - 1]});
+    std::vector<int> out_shape(y_broadcast_dims.begin(),
+                               y_broadcast_dims.end());
+    return CreateOpRunAndReturnTensor("triangular_solve", inputs, attrs,
+                                      out_shape);
+  }
+
+  framework::Tensor ConcatTwoTensors(const framework::Tensor& x,
+                                     const framework::Tensor& y, int axis) {
+    framework::AttributeMap attrs;
+    attrs["axis"] = axis;
+    std::vector<framework::DDim> inputs_dims({x.dims(), y.dims()});
+    NameInTensorMap inputs({{"X", {&x, &y}}});
+    size_t axis_ =
+        ComputeAxisForConcatOp(static_cast<int64_t>(axis),
+                               static_cast<int64_t>(inputs_dims[0].size()));
+    framework::DDim out_dims =
+        ComputeAndCheckShapeForConcatOp(true, inputs_dims, axis_);
+    if (out_dims[axis_] < 0) {
+      out_dims[axis_] = -1;
+    }
+    std::vector<int> out_shape = framework::vectorize<int>(out_dims);
+    return CreateOpRunAndReturnTensor("concat", inputs, attrs, out_shape);
+  }
+
   Tensor Conj(const Tensor& x) {
     Tensor out;
     auto* out_data = out.mutable_data<T>(x.dims(), context.GetPlace());
diff --git a/python/paddle/fluid/tests/unittests/CMakeLists.txt b/python/paddle/fluid/tests/unittests/CMakeLists.txt
index 9c3f9cbad5..64c247e56d 100644
--- a/python/paddle/fluid/tests/unittests/CMakeLists.txt
+++ b/python/paddle/fluid/tests/unittests/CMakeLists.txt
@@ -975,6 +975,7 @@ set_tests_properties(test_lstm_cudnn_op PROPERTIES TIMEOUT 120)
 set_tests_properties(test_stack_op PROPERTIES TIMEOUT 120)
 set_tests_properties(test_bilinear_interp_v2_op PROPERTIES TIMEOUT 120)
 set_tests_properties(test_svd_op PROPERTIES TIMEOUT 80)
+set_tests_properties(test_qr_op PROPERTIES TIMEOUT 60)
 set_tests_properties(test_deformable_psroi_pooling PROPERTIES TIMEOUT 120)
 set_tests_properties(test_trilinear_interp_v2_op PROPERTIES TIMEOUT 120)
 set_tests_properties(test_imperative_static_runner_mnist PROPERTIES TIMEOUT 120)
diff --git a/python/paddle/fluid/tests/unittests/test_qr_op.py b/python/paddle/fluid/tests/unittests/test_qr_op.py
index ea2aaf3f00..4be46837a6 100644
--- a/python/paddle/fluid/tests/unittests/test_qr_op.py
+++ b/python/paddle/fluid/tests/unittests/test_qr_op.py
@@ -21,6 +21,96 @@ import paddle
 import paddle.fluid as fluid
 import paddle.fluid.layers as layers
 import paddle.fluid.core as core
+from op_test import OpTest
+
+
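+# TestQrOp exercises the "qr" operator directly: the forward outputs are
+# compared against np.linalg.qr (computed batch by batch), and the backward
+# pass added in this patch is verified by OpTest's numeric gradient check in
+# test_check_grad_normal.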
self.inputs = {"X": a} + self.attrs = {"mode": self.get_mode()} + self.outputs = {"Q": q, "R": r} + + def get_dtype(self): + return "float64" + + def get_mode(self): + return "reduced" + + def get_shape(self): + return (11, 11) + + def get_input_and_output(self): + dtype = self.get_dtype() + shape = self.get_shape() + mode = self.get_mode() + assert mode != "r", "Cannot be backward in r mode." + a = np.random.rand(*shape).astype(dtype) + m = a.shape[-2] + n = a.shape[-1] + min_mn = min(m, n) + if mode == "reduced": + k = min_mn + else: + k = m + q_shape = list(a.shape[:-2]) + q_shape.extend([m, k]) + r_shape = list(a.shape[:-2]) + r_shape.extend([k, n]) + q = np.zeros(q_shape).astype(dtype) + r = np.zeros(r_shape).astype(dtype) + batch_size = a.size // (a.shape[-1] * a.shape[-2]) + for i in range(batch_size): + coord = np.unravel_index(i, a.shape[:-2]) + tmp_q, tmp_r = np.linalg.qr(a[coord], mode=mode) + q[coord] = tmp_q + r[coord] = tmp_r + return a, q, r + + def test_check_output(self): + self.check_output() + + def test_check_grad_normal(self): + self.check_grad(['X'], ['Q', 'R']) + + +class TestQrOpCase1(TestQrOp): + def get_shape(self): + return (10, 12) + + +class TestQrOpCase2(TestQrOp): + def get_shape(self): + return (16, 15) + + +class TestQrOpCase3(TestQrOp): + def get_shape(self): + return (2, 12, 16) + + +class TestQrOpCase4(TestQrOp): + def get_shape(self): + return (3, 16, 15) + + +class TestQrOpCase5(TestQrOp): + def get_mode(self): + return "complete" + + def get_shape(self): + return (10, 12) + + +class TestQrOpCase6(TestQrOp): + def get_mode(self): + return "complete" + + def get_shape(self): + return (2, 10, 12) class TestQrAPI(unittest.TestCase): @@ -169,5 +259,4 @@ class TestQrAPI(unittest.TestCase): if __name__ == "__main__": - paddle.enable_static() unittest.main() -- GitLab