未验证 提交 a1275c8b 编写于 作者: Z zhiboniu 提交者: GitHub

add lu_op backward (#38616)

上级 8d32cef8
......@@ -149,7 +149,67 @@ class LUKernel : public framework::OpKernel<T> {
}
};
/// Grad-op maker for `lu`: fills in the description of the `lu_grad` op from
/// the forward op's inputs, outputs and attributes.
template <typename T>
class LUOpGradMaker : public framework::SingleGradOpMaker<T> {
 public:
  using framework::SingleGradOpMaker<T>::SingleGradOpMaker;

 protected:
  /// Wires `lu_grad`: it receives X, Out, Pivots and Out@GRAD, produces
  /// X@GRAD, and carries over every attribute of the forward op.
  void Apply(GradOpPtr<T> grad_op) const override {
    const auto out_grad_name = framework::GradVarName("Out");
    const auto x_grad_name = framework::GradVarName("X");

    grad_op->SetType("lu_grad");

    // Inputs of the backward op.
    grad_op->SetInput("X", this->Input("X"));
    grad_op->SetInput("Out", this->Output("Out"));
    grad_op->SetInput("Pivots", this->Output("Pivots"));
    grad_op->SetInput(out_grad_name, this->OutputGrad("Out"));

    // Output of the backward op.
    grad_op->SetOutput(x_grad_name, this->InputGrad("X"));

    grad_op->SetAttrMap(this->Attrs());
  }
};
/// Var-type inference for `lu_grad`: X@GRAD takes both its variable type and
/// its data type from the first "X" input.
class LUGradOpVarTypeInference : public framework::VarTypeInference {
 public:
  void operator()(framework::InferVarTypeContext *ctx) const override {
    // Read both properties of X before touching any output.
    const auto x_var_type = ctx->GetInputType("X", 0);
    const auto x_data_type = ctx->GetInputDataType("X", 0);

    const auto x_grad_name = framework::GradVarName("X");
    ctx->SetOutputType(x_grad_name, x_var_type, framework::ALL_ELEMENTS);
    ctx->SetOutputDataType(x_grad_name, x_data_type,
                           framework::ALL_ELEMENTS);
  }
};
/// Operator definition for `lu_grad`, the backward op of `lu`.
class LUGradOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  /// Requires the inputs X, Out, Pivots and Out@GRAD, then gives X@GRAD
  /// (when it is present as an output) the same dims as X.
  void InferShape(framework::InferShapeContext *ctx) const override {
    // These checks run for the backward op, so a missing input is reported
    // against "lu_grad" rather than the forward op "lu".
    OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "lu_grad");
    OP_INOUT_CHECK(ctx->HasInput("Out"), "Input", "Out", "lu_grad");
    OP_INOUT_CHECK(ctx->HasInput("Pivots"), "Input", "Pivots", "lu_grad");
    OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Out")), "Input",
                   "Out@GRAD", "lu_grad");

    auto x_dims = ctx->GetInputDim("X");
    auto x_grad_name = framework::GradVarName("X");
    if (ctx->HasOutput(x_grad_name)) {
      ctx->SetOutputDim(x_grad_name, x_dims);
    }
  }

 protected:
  /// Selects the kernel by the data type of X and the place of the
  /// execution context.
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext &ctx) const override {
    auto dtype = OperatorWithKernel::IndicateVarDataType(ctx, "X");
    return framework::OpKernelType(dtype, ctx.GetPlace());
  }
};
// Inplace inferer for the forward op, declared with the pair {"X", "Out"}.
// NOTE(review): presumably this allows Out to reuse X's buffer; confirm
// against the definition of DECLARE_INPLACE_OP_INFERER.
DECLARE_INPLACE_OP_INFERER(LUOpInplaceInferer, {"X", "Out"});
// Inplace inferer for the backward op, declared with the pair
// {Out@GRAD, X@GRAD}.
DECLARE_INPLACE_OP_INFERER(LUGradOpInplaceInferer,
{framework::GradVarName("Out"),
framework::GradVarName("X")});
} // namespace operators
} // namespace paddle
......@@ -157,6 +217,13 @@ namespace ops = paddle::operators;
namespace plat = paddle::platform;
// Register the forward op `lu` with its maker, var-type inference, the
// grad-op maker (instantiated for both framework::OpDesc and
// imperative::OpBase) and its inplace inferer.
REGISTER_OPERATOR(lu, ops::LUOp, ops::LUOpMaker, ops::LUOpVarTypeInference,
ops::LUOpGradMaker<paddle::framework::OpDesc>,
ops::LUOpGradMaker<paddle::imperative::OpBase>,
ops::LUOpInplaceInferer);
// Register the backward op `lu_grad` with its var-type inference and its
// inplace inferer.
REGISTER_OPERATOR(lu_grad, ops::LUGradOp, ops::LUGradOpVarTypeInference,
ops::LUGradOpInplaceInferer);
// CPU kernels: forward and backward are both registered for float and double.
REGISTER_OP_CPU_KERNEL(lu, ops::LUKernel<float>, ops::LUKernel<double>);
REGISTER_OP_CPU_KERNEL(lu_grad,
ops::LUGradKernel<plat::CPUDeviceContext, float>,
ops::LUGradKernel<plat::CPUDeviceContext, double>);
......@@ -152,5 +152,8 @@ namespace plat = paddle::platform;
// CUDA kernels: forward and backward are both registered for float and
// double. The backward kernel is the same LUGradKernel template as on CPU,
// instantiated with CUDADeviceContext.
REGISTER_OP_CUDA_KERNEL(lu, ops::LUCUDAKernel<float>,
ops::LUCUDAKernel<double>);
REGISTER_OP_CUDA_KERNEL(lu_grad,
ops::LUGradKernel<plat::CUDADeviceContext, float>,
ops::LUGradKernel<plat::CUDADeviceContext, double>);
#endif // not PADDLE_WITH_HIP
......@@ -470,5 +470,228 @@ void Unpack_Pivot(const DeviceContext& dev_ctx, const framework::Tensor& Pivot,
}
}
/// Backward kernel of `lu`: computes X@GRAD from the packed LU result `Out`,
/// the `Pivots` and the incoming gradient Out@GRAD.
///
/// Notation used in the comments: the last two dims of X are m x n,
/// k = min(m, n), and "^H" means conjugate transpose. The kernel builds
///   phi = strictly-lower(L_k^H * dOut_k) + upper(dOut_k * U_k^H)
/// on the k x k leading blocks, corrects it with the non-square remainder of
/// U (m < n) or L (m > n), and then applies the triangular solves and the
/// permutation matrix to obtain dX.
///
/// NOTE(review): the exact semantics of the helpers used here (LU_Unpack,
/// Tensor_narrow, TrilTriuCompute, triangular_solve, Unpack_Pivot,
/// SetValueCompute_dispatch) are defined outside this block; the comments
/// describing what they do are inferred from how they are called and should
/// be confirmed against their definitions.
template <typename DeviceContext, typename T>
class LUGradKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    auto xin = ctx.Input<framework::Tensor>("X");
    auto out = ctx.Input<framework::Tensor>("Out");
    auto P = ctx.Input<framework::Tensor>("Pivots");
    auto dout = ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
    auto dx = ctx.Output<framework::Tensor>(framework::GradVarName("X"));
    dx->mutable_data<T>(ctx.GetPlace());

    const auto& dev_ctx = ctx.template device_context<DeviceContext>();
    math::DeviceIndependenceTensorOperations<DeviceContext, T> helper(ctx);
    auto blas = math::GetBlas<DeviceContext, T>(ctx);

    // Matrix sizes come from the last two dims of X.
    auto xdims = xin->dims();
    int xrank = xdims.size();
    int64_t m = xdims[xrank - 2];
    int64_t n = xdims[xrank - 1];
    int64_t k = std::min(m, n);

    framework::Tensor L, U, L_narrow, U_narrow, L_narrow_mH, U_narrow_mH,
        grad_narrow;
    LU_Unpack<DeviceContext, T>(dev_ctx, out, &L, &U);

    // Leading k x k blocks of L, U and Out@GRAD (presumably rows [0, k) and
    // columns [0, k) -- confirm Tensor_narrow's argument order).
    Tensor_narrow<DeviceContext, T>(ctx, &L, &L_narrow, 0, k, 0, k);
    Tensor_narrow<DeviceContext, T>(ctx, &U, &U_narrow, 0, k, 0, k);
    Tensor_narrow<DeviceContext, T>(ctx, dout, &grad_narrow, 0, k, 0, k);
    auto graddims = grad_narrow.dims();

    // L_k^H and U_k^H: conjugate first, then transpose.
    Tensor_Conj<DeviceContext, T>(dev_ctx, L_narrow, &L_narrow_mH);
    Tensor_Conj<DeviceContext, T>(dev_ctx, U_narrow, &U_narrow_mH);
    L_narrow_mH = helper.Transpose(L_narrow_mH);
    U_narrow_mH = helper.Transpose(U_narrow_mH);

    auto LmHdims = L_narrow_mH.dims();
    auto UmHdims = U_narrow_mH.dims();

    // phi_L = L_k^H * dOut_k and phi_U = dOut_k * U_k^H.
    framework::Tensor phi_L, phi_U, phi, psi;
    phi_L.Resize(LmHdims);
    phi_L.mutable_data<T>(ctx.GetPlace());
    phi_U.Resize(UmHdims);
    phi_U.mutable_data<T>(ctx.GetPlace());
    auto mat_dim_l = math::CreateMatrixDescriptor(LmHdims, 0, false);
    auto mat_dim_u = math::CreateMatrixDescriptor(UmHdims, 0, false);
    auto mat_dim_g = math::CreateMatrixDescriptor(graddims, 0, false);
    blas.MatMul(L_narrow_mH, mat_dim_l, grad_narrow, mat_dim_g,
                static_cast<T>(1), &phi_L, static_cast<T>(0));
    blas.MatMul(grad_narrow, mat_dim_g, U_narrow_mH, mat_dim_u,
                static_cast<T>(1), &phi_U, static_cast<T>(0));

    // Mask both products in place (input and output pointers are the same).
    // Presumably (-1, true) keeps the strictly lower triangle and (0, false)
    // keeps the upper triangle including the diagonal.
    auto phil_rank = LmHdims.size();
    auto phiu_rank = UmHdims.size();
    platform::ForRange<DeviceContext> l_for_range(dev_ctx, phi_L.numel());
    TrilTriuCompute<T> tril_computer(phi_L.data<T>(), -1, true,
                                     LmHdims[phil_rank - 2],
                                     LmHdims[phil_rank - 1], phi_L.data<T>());
    l_for_range(tril_computer);

    platform::ForRange<DeviceContext> u_for_range(dev_ctx, phi_U.numel());
    TrilTriuCompute<T> triu_computer(phi_U.data<T>(), 0, false,
                                     UmHdims[phiu_rank - 2],
                                     UmHdims[phiu_rank - 1], phi_U.data<T>());
    u_for_range(triu_computer);

    Tensor_Add<DeviceContext, T>(dev_ctx, phi_L, phi_U, &phi);

    // psi has the shape of X and starts out as all zeros; sub-blocks are
    // written into it below.
    psi.Resize(xdims);
    psi.mutable_data<T>(ctx.GetPlace());
    math::SetConstant<DeviceContext, T> setter;
    setter(dev_ctx, &psi, static_cast<T>(0));

    // Slice description shared by all SetValueCompute_dispatch calls: the
    // two sliced axes are always the last two dims.
    std::vector<int64_t> axes = {xrank - 2, xrank - 1};
    std::vector<int64_t> slice_starts(2, 0);
    std::vector<int64_t> slice_ends(2, 0);
    auto valuedims = vectorize(xdims);

    framework::Tensor Pmat;
    Unpack_Pivot<DeviceContext, T>(dev_ctx, *P, &Pmat, m, k);

    if (m <= n) {
      if (k < n) {
        // Wide case: U has extra columns [k, n) beyond the square block.
        framework::Tensor U_complement, U_grad_complement, phi_complement,
            phi_complement_l;
        Tensor_narrow<DeviceContext, T>(ctx, &U, &U_complement, 0, k, k, n);
        Tensor_narrow<DeviceContext, T>(ctx, dout, &U_grad_complement, 0, k, k,
                                        n);
        // U_complement^H: transpose first, then conjugate in place.
        framework::Tensor U_complement_mH = helper.Transpose(U_complement);
        Tensor_Conj<DeviceContext, T>(dev_ctx, U_complement_mH,
                                      &U_complement_mH);

        auto mat_dim_ug =
            math::CreateMatrixDescriptor(U_grad_complement.dims(), 0, false);
        auto mat_dim_uc =
            math::CreateMatrixDescriptor(U_complement_mH.dims(), 0, false);
        auto phidims = UmHdims;
        phidims[UmHdims.size() - 2] = k;
        phidims[UmHdims.size() - 1] = k;
        phi_complement.Resize(phidims);
        phi_complement.mutable_data<T>(ctx.GetPlace());
        // phi_complement = dOut[:, k:] * U[:, k:]^H  (k x k).
        blas.MatMul(U_grad_complement, mat_dim_ug, U_complement_mH, mat_dim_uc,
                    static_cast<T>(1), &phi_complement, static_cast<T>(0));

        // phi -= (masked with (-1, true)) phi_complement.
        phi_complement_l.Resize(phidims);
        phi_complement_l.mutable_data<T>(ctx.GetPlace());
        const auto H = phidims[phidims.size() - 2];
        const auto W = phidims[phidims.size() - 1];
        platform::ForRange<DeviceContext> x_for_range(dev_ctx,
                                                      phi_complement.numel());
        TrilTriuCompute<T> complement_tril(phi_complement.data<T>(), -1, true,
                                           H, W, phi_complement_l.data<T>());
        x_for_range(complement_tril);
        Tensor_Sub<DeviceContext, T>(dev_ctx, phi, phi_complement_l, &phi);

        // psi[0:k, k:n] = dOut[0:k, k:n].
        slice_starts[0] = 0;
        slice_starts[1] = k;
        slice_ends[0] = k;
        slice_ends[1] = n;
        valuedims[xrank - 2] = k;
        valuedims[xrank - 1] = n - k;
        SetValueCompute_dispatch<DeviceContext, T>(
            ctx, &psi, &U_grad_complement, &psi, axes, &slice_starts,
            &slice_ends, valuedims, xrank);
      }

      // psi_principal = phi * U_k^{-H}, obtained by solving against phi^H
      // and conjugate-transposing the solution.
      framework::Tensor psi_principal, phi_mH, psi_tmp;
      Tensor_Conj<DeviceContext, T>(dev_ctx, phi, &phi_mH);
      phi_mH = helper.Transpose(phi_mH);
      triangular_solve<DeviceContext, T>(dev_ctx, U_narrow, phi_mH,
                                         &psi_principal, true, false, false);
      Tensor_Conj<DeviceContext, T>(dev_ctx, psi_principal, &psi_principal);
      psi_principal = helper.Transpose(psi_principal);

      // psi[0:k, 0:k] = psi_principal.
      slice_starts[0] = 0;
      slice_starts[1] = 0;
      slice_ends[0] = k;
      slice_ends[1] = k;
      valuedims[xrank - 2] = k;
      valuedims[xrank - 1] = k;
      SetValueCompute_dispatch<DeviceContext, T>(ctx, &psi, &psi_principal,
                                                 &psi, axes, &slice_starts,
                                                 &slice_ends, valuedims, xrank);

      // dX = Pmat * (L_k^{-H} * psi). The last `true` flag is presumably
      // "unitriangular", matching L's unit diagonal -- confirm.
      triangular_solve<DeviceContext, T>(dev_ctx, L_narrow_mH, psi, &psi_tmp,
                                         true, false, true);
      auto mat_dim_p = math::CreateMatrixDescriptor(Pmat.dims(), 0, false);
      auto mat_dim_b = math::CreateMatrixDescriptor(psi_tmp.dims(), 0, false);
      blas.MatMul(Pmat, mat_dim_p, psi_tmp, mat_dim_b, static_cast<T>(1), dx,
                  static_cast<T>(0));
    } else {
      // Tall case (m > n): L has extra rows [k, m) beyond the square block.
      framework::Tensor L_complement, L_grad_complement, phi_complement,
          phi_complement_u;
      Tensor_narrow<DeviceContext, T>(ctx, &L, &L_complement, k, m, 0, k);
      Tensor_narrow<DeviceContext, T>(ctx, dout, &L_grad_complement, k, m, 0,
                                      k);
      // L_complement^H: transpose first, then conjugate in place.
      framework::Tensor L_complement_mH = helper.Transpose(L_complement);
      Tensor_Conj<DeviceContext, T>(dev_ctx, L_complement_mH, &L_complement_mH);

      auto mat_dim_lg =
          math::CreateMatrixDescriptor(L_grad_complement.dims(), 0, false);
      auto mat_dim_lc =
          math::CreateMatrixDescriptor(L_complement_mH.dims(), 0, false);
      auto phidims = LmHdims;
      phidims[LmHdims.size() - 2] = k;
      phidims[LmHdims.size() - 1] = k;
      phi_complement.Resize(phidims);
      phi_complement.mutable_data<T>(ctx.GetPlace());
      // phi_complement = L[k:, :]^H * dOut[k:, :]  (k x k).
      blas.MatMul(L_complement_mH, mat_dim_lc, L_grad_complement, mat_dim_lg,
                  static_cast<T>(1), &phi_complement, static_cast<T>(0));

      // phi -= (masked with (0, false)) phi_complement.
      phi_complement_u.Resize(phidims);
      phi_complement_u.mutable_data<T>(ctx.GetPlace());
      const auto H = phidims[phidims.size() - 2];
      const auto W = phidims[phidims.size() - 1];
      platform::ForRange<DeviceContext> x_for_range(dev_ctx,
                                                    phi_complement.numel());
      TrilTriuCompute<T> complement_triu(phi_complement.data<T>(), 0, false, H,
                                         W, phi_complement_u.data<T>());
      x_for_range(complement_triu);
      Tensor_Sub<DeviceContext, T>(dev_ctx, phi, phi_complement_u, &phi);

      // psi[k:m, 0:k] = dOut[k:m, 0:k].
      slice_starts[0] = k;
      slice_starts[1] = 0;
      slice_ends[0] = m;
      slice_ends[1] = k;
      valuedims[xrank - 2] = m - k;
      valuedims[xrank - 1] = k;
      SetValueCompute_dispatch<DeviceContext, T>(ctx, &psi, &L_grad_complement,
                                                 &psi, axes, &slice_starts,
                                                 &slice_ends, valuedims, xrank);

      // psi_principal = L_k^{-H} * phi, written into psi[0:k, 0:k].
      // U_narrow_conj holds conj(U_k) only (no transpose); it gets its own
      // name so it cannot be confused with the outer U_narrow_mH.
      framework::Tensor psi_principal, psi_tmp, U_narrow_conj;
      triangular_solve<DeviceContext, T>(dev_ctx, L_narrow_mH, phi,
                                         &psi_principal, true, false, true);
      slice_starts[0] = 0;
      slice_starts[1] = 0;
      slice_ends[0] = k;
      slice_ends[1] = k;
      valuedims[xrank - 2] = k;
      valuedims[xrank - 1] = k;
      SetValueCompute_dispatch<DeviceContext, T>(ctx, &psi, &psi_principal,
                                                 &psi, axes, &slice_starts,
                                                 &slice_ends, valuedims, xrank);

      // psi_tmp = Pmat * psi.
      psi_tmp.Resize(psi.dims());
      psi_tmp.mutable_data<T>(ctx.GetPlace());
      auto mat_dim_p = math::CreateMatrixDescriptor(Pmat.dims(), 0, false);
      auto mat_dim_b = math::CreateMatrixDescriptor(psi.dims(), 0, false);
      blas.MatMul(Pmat, mat_dim_p, psi, mat_dim_b, static_cast<T>(1), &psi_tmp,
                  static_cast<T>(0));

      // dX = (Pmat * psi) * U_k^{-H}, computed on the transposed system:
      // solve conj(U_k) * Y = (Pmat * psi)^T, then dX = Y^T.
      psi_tmp = helper.Transpose(psi_tmp);
      Tensor_Conj<DeviceContext, T>(dev_ctx, U_narrow, &U_narrow_conj);
      triangular_solve<DeviceContext, T>(dev_ctx, U_narrow_conj, psi_tmp, &psi,
                                         true, false, false);
      *dx = helper.Transpose(psi);
    }
  }
};
} // namespace operators
} // namespace paddle
......@@ -140,6 +140,9 @@ class TestLUOp(OpTest):
# Runs self.check_output() with its default arguments.
def test_check_output(self):
self.check_output()
# Runs self.check_grad(['X'], ['Out']).
# NOTE(review): presumably this checks the gradient of output 'Out' with
# respect to input 'X' -- confirm against OpTest.check_grad.
def test_check_grad(self):
self.check_grad(['X'], ['Out'])
# m = n 2D
class TestLUOp2(TestLUOp):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册