未验证 提交 b91e8eec 编写于 作者: J jiangcheng 提交者: GitHub

add gradient kernel of det op and slogdet op (#36013)

* add gradient kernel of det op and slogdet op

* fix CI APPROVAL problem
上级 787273ed
......@@ -48,6 +48,8 @@ class DeterminantGradOp : public framework::OperatorWithKernel {
OP_INOUT_CHECK(ctx->HasInput("Input"), "Input", "Input",
"DeterminantGradOp");
OP_INOUT_CHECK(ctx->HasInput("Out"), "Input", "Out", "DeterminantGradOp");
OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Out")), "Input",
framework::GradVarName("Out"), "DeterminantGradOp");
OP_INOUT_CHECK(ctx->HasOutput(framework::GradVarName("Input")), "Output",
framework::GradVarName("Input"), "DeterminantGradOp");
......@@ -117,7 +119,8 @@ class SlogDeterminantGradOp : public framework::OperatorWithKernel {
"SlogDeterminantGradOp");
OP_INOUT_CHECK(ctx->HasInput("Out"), "Input", "Out",
"SlogDeterminantGradOp");
OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Out")), "Input",
framework::GradVarName("Out"), "SlogDeterminantGradOp");
OP_INOUT_CHECK(ctx->HasOutput(framework::GradVarName("Input")), "Output",
framework::GradVarName("Input"), "SlogDeterminantGradOp");
......@@ -179,7 +182,7 @@ REGISTER_OPERATOR(slogdeterminant, ops::SlogDeterminantOp,
ops::SlogDeterminantGradOpMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(slogdeterminant_grad,
ops::DeterminantGradOp) // reuse det grad op
ops::SlogDeterminantGradOp) // reuse det grad op
REGISTER_OP_CPU_KERNEL(
slogdeterminant, ops::SlogDeterminantKernel<plat::CPUDeviceContext, float>,
......@@ -187,5 +190,5 @@ REGISTER_OP_CPU_KERNEL(
REGISTER_OP_CPU_KERNEL(
slogdeterminant_grad,
ops::DeterminantGradKernel<plat::CPUDeviceContext, float>,
ops::DeterminantGradKernel<plat::CPUDeviceContext, double>);
ops::SlogDeterminantGradKernel<plat::CPUDeviceContext, float>,
ops::SlogDeterminantGradKernel<plat::CPUDeviceContext, double>);
......@@ -14,42 +14,6 @@ limitations under the License. */
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/determinant_op.h"
#include "paddle/fluid/platform/cuda_primitives.h"
namespace paddle {
namespace operators {
using platform::PADDLE_CUDA_NUM_THREADS;
using Tensor = framework::Tensor;
template <typename T>
__global__ void DeterminantGrad(const size_t numel, T* out) {
int tid = threadIdx.x + blockIdx.x * blockDim.x;
if (tid < numel) {
out[tid] = static_cast<T>(1);
}
}
template <typename T>
class DeterminantGradCUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
const auto* dout = context.Input<Tensor>(framework::GradVarName("Out"));
const T* dout_data = dout->data<T>();
auto dout_dim = vectorize(dout->dims());
auto* dx = context.Output<Tensor>(framework::GradVarName("Input"));
T* dx_data = dx->mutable_data<T>(context.GetPlace());
int64_t numel = dx->numel();
for (int64_t idx = 0; idx < numel; idx++) {
dx_data[idx] = static_cast<T>(1);
}
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
namespace plat = paddle::platform;
......
......@@ -19,7 +19,11 @@
#include <cmath>
#include <vector>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/math/complex_functors.h"
#include "paddle/fluid/operators/math/matrix_inverse.h"
#include "paddle/fluid/operators/svd_helper.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/for_range.h"
namespace paddle {
namespace operators {
......@@ -48,11 +52,10 @@ class EigenMatrix<double> {
inline int64_t GetBatchCount(const framework::DDim dims) {
int64_t batch_count = 1;
auto dim_size = dims.size();
PADDLE_ENFORCE_GT(dim_size, 2,
platform::errors::InvalidArgument(
"To get the number of batch square matrices, "
"the size of dimension should greater than 2.",
dim_size));
PADDLE_ENFORCE_GE(
dim_size, 2,
platform::errors::InvalidArgument(
"the input matrix dimension size should greater than 2."));
// Cumulative multiplying each dimension until the last 2 to get the batch
// count,
......@@ -77,7 +80,7 @@ struct DeterminantFunctor {
auto end_iter = input_vec.begin() + (i + 1) * rank * rank;
std::vector<T> sub_vec(begin_iter,
end_iter); // get every square matrix data
Eigen::MatrixXf matrix(rank, rank);
typename EigenMatrix<T>::MatrixType matrix(rank, rank);
for (int64_t i = 0; i < rank; ++i) {
for (int64_t j = 0; j < rank; ++j) {
matrix(i, j) = sub_vec[rank * i + j];
......@@ -109,41 +112,169 @@ class DeterminantKernel : public framework::OpKernel<T> {
"the input matrix should be square matrix."));
auto rank = input_dim[input_dim_size - 1]; // square matrix length
DeterminantFunctor<T>()(*input, context, rank, batch_count, output);
auto output_dims =
framework::slice_ddim(input->dims(), 0, input_dim_size - 2);
if (input_dim_size > 2) {
auto output_dims =
framework::slice_ddim(input->dims(), 0, input_dim_size - 2);
output->Resize(output_dims);
} else {
// when input is a two-dimension matrix, The det value is a number.
output->Resize({1});
}
VLOG(2) << "output dim:" << output->dims();
}
};
template <typename T>
struct FoundZeroFunctor {
FoundZeroFunctor(const T* x, int64_t numel, bool* res)
: x_(x), numel_(numel), res_(res) {}
HOSTDEVICE void operator()(size_t idx) const {
if (*res_ || idx >= static_cast<size_t>(numel_)) {
// founded zero number
return;
}
*res_ = (x_[idx] == static_cast<T>(0));
}
const T* x_;
int64_t numel_;
bool* res_;
};
template <typename DeviceContext, typename T>
inline bool CheckMatrixInvertible(const framework::ExecutionContext& ctx,
const framework::Tensor* det) {
auto& dev_ctx = ctx.template device_context<DeviceContext>();
auto numel = det->numel();
framework::Tensor dev_tensor;
auto* data = dev_tensor.mutable_data<bool>({1}, ctx.GetPlace());
// set false
math::SetConstant<DeviceContext, bool> zero;
zero(dev_ctx, &dev_tensor, false);
// find whether zero
platform::ForRange<DeviceContext> for_range(dev_ctx, numel);
FoundZeroFunctor<T> functor(det->data<T>(), numel, data);
for_range(functor);
// copy to host
dev_ctx.Wait();
framework::Tensor cpu_tensor;
framework::TensorCopy(dev_tensor, platform::CPUPlace(), &cpu_tensor);
// if founded zero, the matrix is not invertible
// else the matrix is invertible
auto* res = cpu_tensor.data<bool>();
return !(*res);
}
template <typename DeviceContext, typename T>
class DeterminantGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
PADDLE_THROW(platform::errors::Unimplemented(
"Not support DeterminantGrad at this time."));
auto& dev_ctx = context.template device_context<DeviceContext>();
const auto* input = context.Input<framework::Tensor>("Input");
const auto* det = context.Input<framework::Tensor>("Out");
const auto* grad =
context.Input<framework::Tensor>(framework::GradVarName("Out"));
auto* ddet =
context.Output<framework::Tensor>(framework::GradVarName("Input"));
auto input_dims_size = input->dims().size();
if (input_dims_size > 2) {
PADDLE_ENFORCE_EQ(
grad->dims().size() + 2, input_dims_size,
platform::errors::InvalidArgument(
"The grad tensor of det dims size should 2 less than"
" input tensor's, but here differ %d",
input_dims_size - grad->dims().size()));
} else if (input_dims_size == 2) {
// input dims size 2 and grad dims size 1 is possible
PADDLE_ENFORCE_EQ(
grad->dims().size(), 1,
platform::errors::InvalidArgument(
"The grad tensor of det dims size should 2 less than"
" input tensor's, but here differ %d",
input_dims_size - grad->dims().size()));
} else {
// checked in forward, pass
}
// Check Whether the matrix is invertible
// (matrix A not invertible) == (det(A)=0)
if (!CheckMatrixInvertible<DeviceContext, T>(context, det)) {
// The matrix is not invertible
VLOG(3) << "The input matrix not invertible!";
ddet->Resize(input->dims());
ddet->mutable_data<T>(context.GetPlace());
math::SetConstant<DeviceContext, T> zero;
zero(dev_ctx, ddet, static_cast<T>(0.0f));
return;
}
// The matrix is invertible
// let |A| = Determinant(A)
// Ref to https://people.maths.ox.ac.uk/gilesm/files/NA-08-01.pdf
// we set d|A| = unsqueeze(dA * |A|, [-1, -2]) * inverse(A).transpose(-2,
// -1)
math::DeviceIndependenceTensorOperations<DeviceContext, T> helper(context);
// First: inverse(A)
framework::Tensor inverse_A;
// A must be square matrices!
inverse_A.Resize(input->dims());
inverse_A.mutable_data<T>(context.GetPlace());
math::MatrixInverseFunctor<DeviceContext, T> mat_inv;
mat_inv(dev_ctx, *input, &inverse_A);
VLOG(3) << "inverse(A) dims: " << inverse_A.dims();
// Second: inverse(A).transpose(-2, -1)
framework::Tensor transpose_inverse_A = helper.Transpose(inverse_A);
VLOG(3) << "(dA * |A|).transpose(-2, -1) dims: "
<< transpose_inverse_A.dims();
// Third: dA * |A|
auto mul_dA_detA = helper.Mul(*grad, *det);
VLOG(3) << "dA * |A| dims: " << mul_dA_detA.dims();
// Fourth: unsqueeze(dA * |A|, [-1, -2])
auto unsqueeze1 = helper.Unsqueeze(mul_dA_detA, -1);
auto unsqueeze2 = helper.Unsqueeze(unsqueeze1, -2);
VLOG(3) << "unsqueezed(dA * |A|) dims: " << unsqueeze2.dims();
// Finally: unsqueeze(dA * |A|) * inverse(A)
auto res = helper.Mul(unsqueeze2, transpose_inverse_A);
VLOG(3) << "unsqueeze(dA * |A|) * inverse(A) dims: " << res.dims();
framework::TensorCopy(res, context.GetPlace(), ddet);
ddet->Resize(input->dims());
VLOG(3) << "d|A| dims: " << ddet->dims();
}
};
template <typename T>
struct SlogDeterminantFunctor {
void operator()(const Tensor& input, const framework::ExecutionContext ctx,
int rank, int batch_count, Tensor* output) {
int64_t rank, int64_t batch_count, Tensor* output) {
std::vector<T> input_vec;
std::vector<T> sign_vec;
std::vector<T> log_vec;
std::vector<T> output_vec;
framework::TensorToVector(input, ctx.device_context(), &input_vec);
for (int i = 0; i < batch_count; ++i) { // maybe can be parallel
for (int64_t i = 0; i < batch_count; ++i) { // maybe can be parallel
auto begin_iter = input_vec.begin() + i * rank * rank;
auto end_iter = input_vec.begin() + (i + 1) * rank * rank;
std::vector<T> sub_vec(begin_iter,
end_iter); // get every square matrix data
typename EigenMatrix<T>::MatrixType matrix(rank, rank);
for (int i = 0; i < rank; ++i) {
for (int j = 0; j < rank; ++j) {
for (int64_t i = 0; i < rank; ++i) {
for (int64_t j = 0; j < rank; ++j) {
matrix(i, j) = sub_vec[rank * i + j];
}
}
......@@ -185,6 +316,10 @@ class SlogDeterminantKernel : public framework::OpKernel<T> {
auto rank = input_dim[input_dim_size - 1]; // square matrix length
SlogDeterminantFunctor<T>()(*input, context, rank, batch_count, output);
std::vector<int> output_dim_vec(input_dim.begin(), input_dim.end() - 2);
if (input_dim.size() == static_cast<size_t>(2)) {
// when input is a two-dimension matrix, The det value is a number.
output_dim_vec = {1};
}
output_dim_vec.insert(output_dim_vec.begin(),
2); // make the output dims as same as numpy
auto output_dims = framework::make_ddim(output_dim_vec);
......@@ -197,8 +332,103 @@ template <typename DeviceContext, typename T>
class SlogDeterminantGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
PADDLE_THROW(platform::errors::Unimplemented(
"Not support SlogDeterminantGrad at this time."));
auto& dev_ctx = context.template device_context<DeviceContext>();
const auto* input = context.Input<framework::Tensor>("Input");
const auto* slogdet = context.Input<framework::Tensor>("Out");
const auto* grad =
context.Input<framework::Tensor>(framework::GradVarName("Out"));
auto* dslogdet =
context.Output<framework::Tensor>(framework::GradVarName("Input"));
PADDLE_ENFORCE_EQ(grad->dims()[0], 2,
platform::errors::InvalidArgument(
"The grad tensor of SlogDet should contain two"
" grad: sign and absslogdet, but here %ld.",
grad->dims()[0]));
if (input->dims().size() > 2) {
PADDLE_ENFORCE_EQ(
grad->dims().size() + 1, input->dims().size(),
platform::errors::InvalidArgument(
"The grad tensor of slogdet dims size should 1 less than"
" input tensor's, but here differ %d",
input->dims().size() - grad->dims().size()));
}
// Check Whether the matrix is invertible
// (matrix A not invertible) == (absslogdet(A)=0)
auto slogdet_vec = slogdet->Split(1, 0);
auto absslogdet_val = slogdet_vec[0];
if (!CheckMatrixInvertible<DeviceContext, T>(context, &absslogdet_val)) {
// The matrix is not invertible
VLOG(3) << "The input matrix not invertible!";
dslogdet->Resize(input->dims());
dslogdet->mutable_data<T>(context.GetPlace());
math::SetConstant<DeviceContext, T> zero;
zero(dev_ctx, dslogdet, std::numeric_limits<T>::quiet_NaN());
return;
}
// The matrix is invertible
// let sl|A| = SlogDeterminant(A)
// Ref to https://people.maths.ox.ac.uk/gilesm/files/NA-08-01.pdf
// we set dsl|A| = unsqueeze(dslA, [-1, -2]) *
// inverse(A).conj().transpose(-2, -1)
math::DeviceIndependenceTensorOperations<DeviceContext, T> helper(context);
// First: inverse(A)
framework::Tensor inverse_A;
// A must be square matrices!
inverse_A.Resize(input->dims());
inverse_A.mutable_data<T>(context.GetPlace());
math::MatrixInverseFunctor<DeviceContext, T> mat_inv;
mat_inv(dev_ctx, *input, &inverse_A);
VLOG(3) << "inverse(A) dims: " << inverse_A.dims();
// Second: inverse(A).conj()
framework::Tensor conj_inverse_A;
conj_inverse_A.Resize(inverse_A.dims());
auto numel = input->numel();
auto* conj_data = conj_inverse_A.mutable_data<T>(context.GetPlace(),
size_t(numel * sizeof(T)));
platform::ForRange<DeviceContext> for_range(dev_ctx, numel);
math::ConjFunctor<T> functor(inverse_A.data<T>(), numel, conj_data);
for_range(functor);
VLOG(3) << "inverse(A).conj() dims: " << conj_inverse_A.dims();
// Third: inverse(A).conj().transpose(-2, -1)
framework::Tensor transpose_inverse_A = helper.Transpose(conj_inverse_A);
VLOG(3) << "inverse(A).conj().transpose(-2, -1) dims: "
<< transpose_inverse_A.dims();
// Fourth: split grad value to [sign_grad, absslogdet_grad]
auto grad_vec = grad->Split(1, 0);
auto det_grad = grad_vec[1];
// remmove useless first dimension
int det_grad_size = det_grad.dims().size();
std::vector<int> det_grad_vec;
for (int i = 1; i < det_grad_size; ++i) {
det_grad_vec.emplace_back(det_grad.dims()[i]);
}
det_grad.Resize(det_grad.dims().reshape(det_grad_vec));
// Fifth: unsqueeze(dslA, [-1, -2])
auto unsqueeze1 = helper.Unsqueeze(det_grad, -1);
auto unsqueeze2 = helper.Unsqueeze(unsqueeze1, -2);
VLOG(3) << "unsqueezed(dslA, [-1, -2]) dims: " << unsqueeze2.dims();
// Finally: unsqueeze(dslA) * inverse(A)
auto res = helper.Mul(unsqueeze2, transpose_inverse_A);
VLOG(3) << "unsqueeze(dslA) * inverse(A) dims: " << res.dims();
framework::TensorCopy(res, context.GetPlace(), dslogdet);
dslogdet->Resize(input->dims());
VLOG(3) << "dsl|A| dims: " << dslogdet->dims();
}
};
......
......@@ -16,7 +16,7 @@ from __future__ import print_function
import unittest
import numpy as np
from op_test import OpTest, skip_check_grad_ci
from op_test import OpTest
import paddle
import paddle.nn.functional as F
import paddle.fluid as fluid
......@@ -26,7 +26,6 @@ import paddle.tensor as tensor
paddle.enable_static()
@skip_check_grad_ci(reason="determinant grad is in progress.")
class TestDeterminantOp(OpTest):
def setUp(self):
self.init_data()
......@@ -37,11 +36,11 @@ class TestDeterminantOp(OpTest):
self.check_output()
def test_check_grad(self):
pass
self.check_grad(['Input'], ['Out'])
def init_data(self):
np.random.seed(0)
self.case = np.random.rand(3, 3, 3, 3, 3).astype('float64')
self.case = np.random.rand(3, 3, 3, 5, 5).astype('float64')
self.inputs = {'Input': self.case}
self.target = np.linalg.det(self.case)
......@@ -49,30 +48,25 @@ class TestDeterminantOp(OpTest):
class TestDeterminantOpCase1(TestDeterminantOp):
def init_data(self):
np.random.seed(0)
self.case = np.random.rand(3, 3, 3, 3).astype(np.float32)
self.case = np.random.rand(10, 10).astype('float32')
self.inputs = {'Input': self.case}
self.target = np.linalg.det(self.case)
def test_check_grad(self):
pass
class TestDeterminantOpCase2(TestDeterminantOp):
def init_data(self):
np.random.seed(0)
self.case = np.random.rand(4, 2, 4, 4).astype('float64')
# not invertible matrix
self.case = np.ones([4, 2, 4, 4]).astype('float64')
self.inputs = {'Input': self.case}
self.target = np.linalg.det(self.case)
def test_check_grad(self):
pass
class TestDeterminantAPI(unittest.TestCase):
def setUp(self):
self.shape = [3, 3, 3, 3]
np.random.seed(0)
self.x = np.random.rand(3, 3, 3, 3).astype(np.float32)
self.shape = [3, 3, 5, 5]
self.x = np.random.random(self.shape).astype(np.float32)
self.place = paddle.CPUPlace()
def test_api_static(self):
......@@ -96,7 +90,6 @@ class TestDeterminantAPI(unittest.TestCase):
paddle.enable_static()
@skip_check_grad_ci(reason="slogdeterminant grad is in progress.")
class TestSlogDeterminantOp(OpTest):
def setUp(self):
self.op_type = "slogdeterminant"
......@@ -107,11 +100,12 @@ class TestSlogDeterminantOp(OpTest):
self.check_output()
def test_check_grad(self):
pass
# the slog det's grad value is always huge
self.check_grad(['Input'], ['Out'], max_relative_error=0.1)
def init_data(self):
np.random.seed(0)
self.case = np.random.rand(3, 3, 3, 3).astype('float64')
self.case = np.random.rand(4, 5, 5).astype('float64')
self.inputs = {'Input': self.case}
self.target = np.array(np.linalg.slogdet(self.case))
......@@ -126,9 +120,9 @@ class TestSlogDeterminantOpCase1(TestSlogDeterminantOp):
class TestSlogDeterminantAPI(unittest.TestCase):
def setUp(self):
self.shape = [3, 3, 3, 3]
np.random.seed(0)
self.x = np.random.rand(3, 3, 3, 3).astype(np.float32)
self.shape = [3, 3, 5, 5]
self.x = np.random.random(self.shape).astype(np.float32)
self.place = paddle.CPUPlace()
def test_api_static(self):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册