Unverified · Commit 657b6742 authored by Yulong Ao, committed by GitHub

Add the backward support for QR (#38824)

* Add the backward support for QR

* Remove unnecessary comments
Parent 953638e0
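With this change, gradients can flow through `paddle.linalg.qr`. A minimal dygraph sketch (assuming a Paddle build that includes this commit; the tensor values are arbitrary):

```python
import paddle

# A tall (m >= n) matrix; mode='reduced' is the default and is differentiable.
x = paddle.to_tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], stop_gradient=False)
q, r = paddle.linalg.qr(x, mode='reduced')
loss = (q * q).sum() + (r * r).sum()
loss.backward()          # exercises the new QrGradKernel
print(x.grad.shape)      # [3, 2], same shape as x
```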
@@ -19,6 +19,7 @@
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/operators/math/complex_functors.h"
#include "paddle/fluid/operators/svd_helper.h"
#include "paddle/fluid/platform/for_range.h"
namespace paddle {
@@ -79,9 +80,11 @@ class QrCPUKernel : public framework::OpKernel<T> {
q_data = q.mutable_data<math::Real<T>>(
context.GetPlace(),
size_t(batch_size * m * k * sizeof(math::Real<T>)));
memset(q_data, 0, size_t(batch_size * m * k * sizeof(math::Real<T>)));
}
auto* r_data = r.mutable_data<math::Real<T>>(
context.GetPlace(), size_t(batch_size * k * n * sizeof(math::Real<T>)));
memset(r_data, 0, size_t(batch_size * k * n * sizeof(math::Real<T>)));
// Implement QR by calling Eigen
for (int i = 0; i < batch_size; ++i) {
@@ -126,8 +129,124 @@ template <typename DeviceContext, typename T>
class QrGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const {
PADDLE_THROW(platform::errors::InvalidArgument(
"QR doesn't have the backward kernel now and will be supported soon."));
const framework::Tensor& Q = *ctx.Input<framework::Tensor>("Q");
const framework::Tensor& R = *ctx.Input<framework::Tensor>("R");
// Use a different name A instead of X
const framework::Tensor& A = *ctx.Input<framework::Tensor>("X");
const framework::Tensor& dQ =
*ctx.Input<framework::Tensor>(framework::GradVarName("Q"));
const framework::Tensor& dR =
*ctx.Input<framework::Tensor>(framework::GradVarName("R"));
// Use a different name dA instead of dX
framework::Tensor& dA =
*ctx.Output<framework::Tensor>(framework::GradVarName("X"));
dA.mutable_data<math::Real<T>>(ctx.GetPlace());
auto& dev_ctx = ctx.template device_context<DeviceContext>();
math::SetConstant<DeviceContext, T>()(dev_ctx, &dA, T(0));
auto dito = math::DeviceIndependenceTensorOperations<DeviceContext, T>(ctx);
std::string mode = ctx.Attr<std::string>("mode");
bool compute_q, reduced;
std::tie(compute_q, reduced) = _parse_qr_mode(mode);
if (!compute_q) {
PADDLE_THROW(platform::errors::InvalidArgument(
"The derivative of qr is not implemented when mode='r'."));
}
auto a_dims = A.dims();
int a_rank = a_dims.size();
int m = a_dims[a_rank - 2];
int n = a_dims[a_rank - 1];
if ((m > n) && (!reduced)) {
PADDLE_THROW(platform::errors::InvalidArgument(
"The derivative of qr is not implemented when mode='complete' and "
"nrows > ncols."));
}
// m >= n case
auto m_gt_n_case = [](
const framework::ExecutionContext& ctx,
math::DeviceIndependenceTensorOperations<DeviceContext, T>& dito,
const Tensor& dQ, const Tensor& dR, const Tensor& A, const Tensor& Q,
const Tensor& R) -> framework::Tensor {
// Hai-Jun Liao, Jin-Guo Liu, Lei Wang, Tao Xiang (2019). Differentiable
// Programming Tensor Networks.
// https://arxiv.org/abs/1903.09650 Section 3. QR factorization
// dR^H
framework::Tensor R_term;
if (ctx.HasInput(framework::GradVarName("R"))) {
R_term = dito.Matmul(R, dito.Transpose(dR));
} else {
R_term = dito.Fill(framework::vectorize<int>(R.dims()), 0);
}
// dQ^H * Q
framework::Tensor Q_term;
if (ctx.HasInput(framework::GradVarName("Q"))) {
Q_term = dito.Matmul(dito.Transpose(dQ), Q);
} else {
Q_term = dito.Fill(framework::vectorize<int>(R.dims()), 0);
}
framework::Tensor M_tmp1 = dito.Sub(R_term, Q_term);
// Compute M = copyltu(M_tmp1) = tril(M_tmp1, 0) + tril(M_tmp1, -1)^H
framework::Tensor M_tril_0 = dito.TrilTriu(M_tmp1, 0, true);
framework::Tensor M_tril_1 = dito.TrilTriu(M_tmp1, -1, true);
framework::Tensor M = dito.Add(M_tril_0, dito.Transpose(M_tril_1));
framework::Tensor rhs_term;
if (ctx.HasInput(framework::GradVarName("Q"))) {
rhs_term = dito.Add(dQ, dito.Matmul(Q, M));
} else {
rhs_term = dito.Matmul(Q, M);
}
// dA * R^H = rhs_term
auto dA =
dito.TriangularSolve(dito.Transpose(dito.Conj(dito.Transpose(R))),
dito.Transpose(rhs_term),
/*upper=*/true,
/*transpose=*/false,
/*unitriangular=*/false);
return dito.Transpose(dA);
};
if (m >= n) {
auto dA_tmp = m_gt_n_case(ctx, dito, dQ, dR, A, Q, R);
framework::TensorCopy(dA_tmp, dA.place(), &dA);
} else {
// If m < n for input matrices A, we partition A = [X|Y] and R = [U|V]
// Calculate dX and dY individually and concatenate them to get dA
dA.mutable_data<math::Real<T>>(ctx.GetPlace());
auto Y = dito.Slice(A, {-1}, {m}, {n});
auto U = dito.Slice(R, {-1}, {0}, {m});
framework::Tensor dY, dX, dV, dR_tmp, dQ_prime;
if (ctx.HasInput(framework::GradVarName("R"))) {
dV = dito.Slice(dR, {-1}, {m}, {n});
dR_tmp = dito.Slice(dR, {-1}, {0}, {m});
// Y * dV^H
dQ_prime = dito.Matmul(Y, dito.Transpose(dV));
} else {
dV = dito.Fill(framework::vectorize<int>(Y.dims()), 0);
dQ_prime = dito.Fill(framework::vectorize<int>(Q.dims()), 0);
}
if (ctx.HasInput(framework::GradVarName("Q"))) {
dQ_prime = dito.Add(dQ_prime, dQ);
}
dX = m_gt_n_case(ctx, dito, dQ_prime, dR_tmp, A, Q, U);
dY = dito.Matmul(Q, dV);
// Concatenate dX and dY to get dA.
auto dA_tmp = dito.ConcatTwoTensors(dX, dY, -1);
framework::TensorCopy(dA_tmp, dA.place(), &dA);
}
}
};
......
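For reference, the m >= n branch above implements the QR gradient from Liao et al. (2019), Section 3: with A = QR, dA = (dQ + Q·copyltu(M))·R^{-H}, where M = R·dR^H − dQ^H·Q and copyltu copies the lower triangle onto the upper triangle. A minimal NumPy sketch checking the real-valued case against finite differences (illustrative only; `qr_grad` is a hypothetical helper, not the kernel's code path):

```python
import numpy as np

def qr_grad(Q, R, dQ, dR):
    """Gradient w.r.t. A of the reduced QR A = Q @ R, for m >= n (real case)."""
    M = R @ dR.T - dQ.T @ Q
    M = np.tril(M, 0) + np.tril(M, -1).T        # copyltu(M)
    rhs = dQ + Q @ M
    # Solve dA @ R^T = rhs, i.e. dA = rhs @ R^{-T}
    return np.linalg.solve(R, rhs.T).T

np.random.seed(0)
m, n = 6, 4
A = np.random.rand(m, n)
Q, R = np.linalg.qr(A, mode='reduced')
dQ = np.random.rand(*Q.shape)                   # upstream gradient dL/dQ
dR = np.random.rand(*R.shape)                   # upstream gradient dL/dR
dA = qr_grad(Q, R, dQ, dR)

# Finite-difference check of L = sum(dQ * Q(A)) + sum(dR * R(A))
def loss(a):
    q, r = np.linalg.qr(a, mode='reduced')
    return (dQ * q).sum() + (dR * r).sum()

num = np.zeros_like(A)
eps = 1e-6
for i in range(m):
    for j in range(n):
        Ap, Am = A.copy(), A.copy()
        Ap[i, j] += eps
        Am[i, j] -= eps
        num[i, j] = (loss(Ap) - loss(Am)) / (2 * eps)
print(np.max(np.abs(dA - num)))                 # expected to be ~1e-8 or smaller
```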
@@ -146,6 +146,93 @@ static std::vector<int> GetBroadcastShape(InTensors ins) {
return broadcast_shape;
}
static inline framework::DDim ComputeAndCheckShapeForConcatOp(
const bool is_runtime, const std::vector<framework::DDim>& inputs_dims,
const size_t axis) {
const size_t n = inputs_dims.size();
auto out_dims = inputs_dims[0];
size_t in_zero_dims_size = out_dims.size();
for (size_t i = 1; i < n; i++) {
PADDLE_ENFORCE_EQ(inputs_dims[i].size(), out_dims.size(),
platform::errors::InvalidArgument(
"The shape of input[0] and input[%d] "
"is expected to be equal."
"But received input[0]'s shape = "
"[%s], input[%d]'s shape = [%s].",
i, inputs_dims[0], i, inputs_dims[i]));
for (size_t j = 0; j < in_zero_dims_size; j++) {
if (j == axis) {
if (is_runtime) {
out_dims[axis] += inputs_dims[i][j];
} else {
if (inputs_dims[i][j] == -1 || out_dims[j] == -1) {
out_dims[axis] = -1;
} else {
out_dims[axis] += inputs_dims[i][j];
}
}
} else {
bool check_shape =
is_runtime || (inputs_dims[0][j] > 0 && inputs_dims[i][j] > 0);
if (check_shape) {
// At run time check all shapes; at compile time only those already known
PADDLE_ENFORCE_EQ(inputs_dims[0][j], inputs_dims[i][j],
platform::errors::InvalidArgument(
"The %d-th dimension of input[0] and input[%d] "
"is expected to be equal."
"But received input[0]'s shape = "
"[%s], input[%d]'s shape = [%s].",
j, i, inputs_dims[0], i, inputs_dims[i]));
}
if (!is_runtime && out_dims[j] == -1 && inputs_dims[i][j] > 0) {
out_dims[j] = inputs_dims[i][j];
}
}
}
}
return out_dims;
}
static inline int64_t ComputeAxisForConcatOp(int64_t axis, int64_t rank) {
PADDLE_ENFORCE_EQ(
axis >= -rank && axis < rank, true,
platform::errors::InvalidArgument(
"The axis is expected to be in range of [%d, %d), but got %d", -rank,
rank, axis));
if (axis < 0) {
axis = axis + rank;
}
return axis > 0 ? axis : 0;
}
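The two helpers above mirror concat's shape inference: the axis is normalized from the [-rank, rank) range, and the output extent along the concat axis is summed, or left as -1 when any input extent is unknown at compile time. A small Python sketch of the same rule (the name `concat_out_shape` is hypothetical, for illustration only):

```python
def concat_out_shape(shapes, axis, is_runtime=True):
    """Sketch of ComputeAxisForConcatOp + ComputeAndCheckShapeForConcatOp."""
    rank = len(shapes[0])
    assert -rank <= axis < rank
    axis = axis + rank if axis < 0 else axis            # normalize negative axis
    out = list(shapes[0])
    for s in shapes[1:]:
        for j, d in enumerate(s):
            if j == axis:
                # Unknown (-1) extents stay unknown until run time.
                out[j] = -1 if (-1 in (d, out[j]) and not is_runtime) else out[j] + d
            else:
                assert (not is_runtime) or out[j] == d, "non-concat dims must match"
                if out[j] == -1 and d > 0:
                    out[j] = d
    return out

print(concat_out_shape([[2, 3, 4], [2, 5, 4]], axis=-2))                      # [2, 8, 4]
print(concat_out_shape([[2, -1, 4], [2, 5, 4]], axis=1, is_runtime=False))    # [2, -1, 4]
```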
// Prepared for the broadcast operation
static std::vector<int64_t> get_broadcast_batch_portion(
std::vector<int64_t> x, std::vector<int64_t> y) {
size_t size_x = x.size();
size_t size_y = y.size();
size_t size = std::max(size_x, size_y);
std::vector<int64_t> batchPortion(size);
ptrdiff_t i = (ptrdiff_t)size - 1;
for (; i >= 0; --i) {
ptrdiff_t offset = size - i - 1;
ptrdiff_t dim_x = size_x - offset - 1;
ptrdiff_t dim_y = size_y - offset - 1;
int64_t x_size = (dim_x >= 0) ? x[dim_x] : 1;
int64_t y_size = (dim_y >= 0) ? y[dim_y] : 1;
PADDLE_ENFORCE_EQ(
(x_size == y_size || x_size == 1 || y_size == 1), true,
platform::errors::PreconditionNotMet(
"The size of tensor x (%d) must match the size of tensor y "
"(%d) at non-singleton dimension %d.",
x_size, y_size, i));
batchPortion[i] = x_size != 1 ? x_size : y_size;
}
return batchPortion;
}
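get_broadcast_batch_portion applies the usual right-aligned broadcasting rule, but only to the leading (batch) dimensions; the trailing matrix dimensions are handled separately by its callers. A Python sketch of the same rule (illustrative only):

```python
def broadcast_batch_portion(x, y):
    """Right-aligned broadcasting of batch dims, mirroring get_broadcast_batch_portion."""
    out = []
    for offset in range(max(len(x), len(y))):
        xd = x[-1 - offset] if offset < len(x) else 1
        yd = y[-1 - offset] if offset < len(y) else 1
        assert xd == yd or xd == 1 or yd == 1, f"size mismatch: {xd} vs {yd}"
        out.append(xd if xd != 1 else yd)
    return out[::-1]

print(broadcast_batch_portion([5, 1], [3]))        # -> [5, 3]
print(broadcast_batch_portion([2, 1, 4], [3, 1]))  # -> [2, 3, 4]
```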
#define DITO_TRANSPOSE_RANK_CASE(N) \
case N: { \
math::Transpose<DeviceContext, T, N> trans; \
@@ -515,6 +602,54 @@ struct DeviceIndependenceTensorOperations {
return CreateOpRunAndReturnTensor("tril_triu", inputs, attrs, out_shape);
}
framework::Tensor TriangularSolve(const framework::Tensor& x,
const framework::Tensor& y, bool upper,
bool transpose, bool unitriangular) {
framework::AttributeMap attrs;
attrs["upper"] = upper;
attrs["transpose"] = transpose;
attrs["unitriangular"] = unitriangular;
NameInTensorMap inputs({{"X", {&x}}, {"Y", {&y}}});
auto x_dims = x.dims();
auto y_dims = y.dims();
auto y_dims_n = y_dims.size();
std::vector<int64_t> x_dims_vec =
paddle::framework::vectorize<int64_t>(x_dims);
std::vector<int64_t> y_dims_vec =
paddle::framework::vectorize<int64_t>(y_dims);
std::vector<int64_t> x_dims_vec_cut(x_dims_vec.begin(),
x_dims_vec.end() - 2);
std::vector<int64_t> y_dims_vec_cut(y_dims_vec.begin(),
y_dims_vec.end() - 2);
std::vector<int64_t> expand_batch_portion =
get_broadcast_batch_portion(x_dims_vec_cut, y_dims_vec_cut);
std::vector<int64_t> y_broadcast_dims({expand_batch_portion});
y_broadcast_dims.insert(y_broadcast_dims.end(), {y_dims_vec[y_dims_n - 2],
y_dims_vec[y_dims_n - 1]});
std::vector<int> out_shape(y_broadcast_dims.begin(),
y_broadcast_dims.end());
return CreateOpRunAndReturnTensor("triangular_solve", inputs, attrs,
out_shape);
}
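The out_shape requested here is the broadcast of the two inputs' batch portions followed by y's trailing two dimensions, matching the triangular_solve op's output. A self-contained sketch with hypothetical shapes:

```python
import numpy as np

x_shape = (5, 1, 4, 4)   # batched 4x4 triangular systems
y_shape = (3, 4, 2)      # batched right-hand sides, each 4x2
batch = np.broadcast_shapes(x_shape[:-2], y_shape[:-2])   # (5, 3)
out_shape = list(batch) + list(y_shape[-2:])
print(out_shape)         # -> [5, 3, 4, 2]
```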
framework::Tensor ConcatTwoTensors(const framework::Tensor& x,
const framework::Tensor& y, int axis) {
framework::AttributeMap attrs;
attrs["axis"] = axis;
std::vector<framework::DDim> inputs_dims({x.dims(), y.dims()});
NameInTensorMap inputs({{"X", {&x, &y}}});
size_t axis_ =
ComputeAxisForConcatOp(static_cast<int64_t>(axis),
static_cast<int64_t>(inputs_dims[0].size()));
framework::DDim out_dims =
ComputeAndCheckShapeForConcatOp(true, inputs_dims, axis_);
if (out_dims[axis_] < 0) {
out_dims[axis_] = -1;
}
std::vector<int> out_shape = framework::vectorize<int>(out_dims);
return CreateOpRunAndReturnTensor("concat", inputs, attrs, out_shape);
}
Tensor Conj(const Tensor& x) {
Tensor out;
auto* out_data = out.mutable_data<T>(x.dims(), context.GetPlace());
......
@@ -975,6 +975,7 @@ set_tests_properties(test_lstm_cudnn_op PROPERTIES TIMEOUT 120)
set_tests_properties(test_stack_op PROPERTIES TIMEOUT 120)
set_tests_properties(test_bilinear_interp_v2_op PROPERTIES TIMEOUT 120)
set_tests_properties(test_svd_op PROPERTIES TIMEOUT 80)
set_tests_properties(test_qr_op PROPERTIES TIMEOUT 60)
set_tests_properties(test_deformable_psroi_pooling PROPERTIES TIMEOUT 120)
set_tests_properties(test_trilinear_interp_v2_op PROPERTIES TIMEOUT 120)
set_tests_properties(test_imperative_static_runner_mnist PROPERTIES TIMEOUT 120)
......
@@ -21,6 +21,96 @@ import paddle
import paddle.fluid as fluid
import paddle.fluid.layers as layers
import paddle.fluid.core as core
from op_test import OpTest
class TestQrOp(OpTest):
def setUp(self):
paddle.enable_static()
np.random.seed(4)
self.op_type = "qr"
a, q, r = self.get_input_and_output()
self.inputs = {"X": a}
self.attrs = {"mode": self.get_mode()}
self.outputs = {"Q": q, "R": r}
def get_dtype(self):
return "float64"
def get_mode(self):
return "reduced"
def get_shape(self):
return (11, 11)
def get_input_and_output(self):
dtype = self.get_dtype()
shape = self.get_shape()
mode = self.get_mode()
assert mode != "r", "Backward is not supported in 'r' mode."
a = np.random.rand(*shape).astype(dtype)
m = a.shape[-2]
n = a.shape[-1]
min_mn = min(m, n)
if mode == "reduced":
k = min_mn
else:
k = m
q_shape = list(a.shape[:-2])
q_shape.extend([m, k])
r_shape = list(a.shape[:-2])
r_shape.extend([k, n])
q = np.zeros(q_shape).astype(dtype)
r = np.zeros(r_shape).astype(dtype)
batch_size = a.size // (a.shape[-1] * a.shape[-2])
for i in range(batch_size):
coord = np.unravel_index(i, a.shape[:-2])
tmp_q, tmp_r = np.linalg.qr(a[coord], mode=mode)
q[coord] = tmp_q
r[coord] = tmp_r
return a, q, r
def test_check_output(self):
self.check_output()
def test_check_grad_normal(self):
self.check_grad(['X'], ['Q', 'R'])
class TestQrOpCase1(TestQrOp):
def get_shape(self):
return (10, 12)
class TestQrOpCase2(TestQrOp):
def get_shape(self):
return (16, 15)
class TestQrOpCase3(TestQrOp):
def get_shape(self):
return (2, 12, 16)
class TestQrOpCase4(TestQrOp):
def get_shape(self):
return (3, 16, 15)
class TestQrOpCase5(TestQrOp):
def get_mode(self):
return "complete"
def get_shape(self):
return (10, 12)
class TestQrOpCase6(TestQrOp):
def get_mode(self):
return "complete"
def get_shape(self):
return (2, 10, 12)
class TestQrAPI(unittest.TestCase):
@@ -169,5 +259,4 @@ class TestQrAPI(unittest.TestCase):
if __name__ == "__main__":
paddle.enable_static()
unittest.main()