Unverified commit c905a9e9 authored by Lin Manhui, committed by GitHub

[PHI] Move lu_unpack to phi (#44674)

* Add kernel declarations

* Copy kernel implementation code

* Transfer implementation code

* Register new kernels

* Remove old kernels

* Fix code style

* Fix bugs

* mutable_data->HostAlloc

* Transfer infermeta

* Add yaml and update python api

* Add PADDLE_WITH_HIP check

* Update unittests

* Add kernel declarations

* Copy kernel implementation code

* Transfer kernel implementation code

* Register new kernels

* Remove old kernels

* Add lu_unpack_sig

* Fix bugs

* Fix bugs

* Fix bugs

* Optimize directory structure

* Add output checks

* Update include files

* lu_impl.h->lu_kernel_impl.h

* Transfer infermeta

* Add yaml and update python api

* Add check_eager
Co-authored-by: Bobholamovic <linmanhui@baidu.com>
Parent 3948c243
@@ -12,7 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/lu_op.h"
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"
......
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/phi_utils.h"
#include "paddle/fluid/operators/set_value_op.h"
#include "paddle/fluid/operators/svd_helper.h"
#include "paddle/phi/kernels/elementwise_add_kernel.h"
#include "paddle/phi/kernels/elementwise_subtract_kernel.h"
#include "paddle/phi/kernels/funcs/lapack/lapack_function.h"
#include "paddle/phi/kernels/funcs/tril_triu_compute.h"
#include "paddle/phi/kernels/triangular_solve_kernel.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
using LoDTensorArray = framework::LoDTensorArray;
template <typename DeviceContext, typename T, size_t D>
void SetValueCompute(const framework::ExecutionContext& ctx,
framework::Tensor* in,
framework::Tensor* value_tensor,
framework::Tensor* out,
const std::vector<int64_t>& axes,
std::vector<int64_t>* starts,
std::vector<int64_t>* ends,
const std::vector<int64_t>& shape) {
std::vector<int64_t> steps = {1, 1};
std::vector<int64_t> decrease_axes = {};
std::vector<int64_t> none_axes = {};
auto dtype = framework::TransToProtoVarType(in->dtype());
auto in_dims = in->dims();
phi::funcs::CheckAndUpdateSliceAttrs<int64_t>(
in_dims, axes, starts, ends, &steps);
auto slice_dims =
phi::funcs::GetSliceDims(in_dims, axes, *starts, *ends, &steps);
auto decrease_slice_dims =
phi::funcs::GetDecreasedDims(slice_dims, decrease_axes);
auto slice_dims_for_assign = decrease_slice_dims;
if (!none_axes.empty()) {
std::vector<int64_t> slice_dims_with_none;
size_t none_axes_cur = 0, decrease_axes_cur = 0;
for (int i = 0; i < slice_dims.size(); ++i) {
while (none_axes_cur < none_axes.size() &&
none_axes[none_axes_cur] <= i) {
slice_dims_with_none.push_back(1);
none_axes_cur++;
}
if (decrease_axes_cur < decrease_axes.size() &&
decrease_axes[decrease_axes_cur] == i) {
decrease_axes_cur++;
} else {
slice_dims_with_none.push_back(slice_dims[i]);
}
}
while (none_axes_cur < none_axes.size()) {
slice_dims_with_none.push_back(1);
none_axes_cur++;
}
slice_dims_for_assign = phi::make_ddim(slice_dims_with_none);
}
auto place = ctx.GetPlace();
auto& eigen_place =
*ctx.template device_context<DeviceContext>().eigen_device();
// Here copy data from input to avoid data loss at PE and Graph level.
// TODO(liym27): Speed up in the future version.
// - Q: Why not call ShareDataWith to speed this up?
// - A: Because ShareDataWith is not supported on an OP's input and output
// https://github.com/PaddlePaddle/Paddle/wiki/ShareDataWith-and-ShareBufferWith-are-prohibited-in-OP
// - Q: Why not delete Input, given that the input and output are the same
// Tensor at the program level?
// - A: If Input were deleted, the graph would become more complex; for example,
// two ops would point to the output in the graph: op1 -> output <- set_value.
// In that case, we would have to find a way to ensure that set_value runs
// in the order we want.
paddle::framework::TensorCopy(*in, place, out);
Tensor slice_tensor(framework::TransToPhiDataType(dtype)),
pad_tensor(framework::TransToPhiDataType(dtype));
slice_tensor.mutable_data<T>(slice_dims, place);
pad_tensor.mutable_data<T>(in_dims, place);
auto pad_e = framework::EigenTensor<T, D>::From(pad_tensor, in_dims);
auto out_e = framework::EigenTensor<T, D>::From(*out);
auto slice_e = framework::EigenTensor<T, D>::From(slice_tensor, slice_dims);
// Step 1: Set the value of out at `_index` to zero
slice_e.device(eigen_place) = slice_e.constant(T(0));
auto starts_indices = Eigen::DSizes<Eigen::DenseIndex, D>();
auto ends_indices = Eigen::DSizes<Eigen::DenseIndex, D>();
auto strides_indices = Eigen::DSizes<Eigen::DenseIndex, D>();
for (size_t i = 0; i < D; ++i) {
starts_indices[i] = 0;
ends_indices[i] = slice_dims[i];
strides_indices[i] = 1;
}
for (size_t i = 0; i < axes.size(); i++) {
int axis_index = axes[i];
starts_indices[axis_index] = (*starts)[i];
ends_indices[axis_index] = (*ends)[i];
strides_indices[axis_index] = steps[i];
if ((*starts)[i] ==
(*ends)[i]) { // slice is empty, data will not be changed
return;
}
}
out_e.stridedSlice(starts_indices, ends_indices, strides_indices)
.device(eigen_place) = slice_e;
// Step 2: Build a tensor with the same shape as the out tensor, whose data at
// '_index' is the same as value_tensor and zero outside '_index'.
// - Step 2.1 Set slice tensor with value
// NOTE(liym27): [ Why resize slice_tensor here? ]
// A: When broadcasting slice_tensor against value_tensor, the shape of
// slice_tensor should be the decreased dims.
// e.g.
// x[:,0] = value_tensor
// x's shape = [3, 4], value_tensor's shape = [3]
// We get slice_dims = [3, 1], decrease_slice_dims = [3]
// Broadcasting Tensors with shapes [3, 1] and [3] gives shape [3, 3],
// which crosses the border;
// Broadcasting Tensors with shapes [3] and [3] gives shape [3], which is right.
slice_tensor.Resize(slice_dims_for_assign);
if (value_tensor != nullptr) {
CheckIsDimsMatch(slice_dims_for_assign, value_tensor->dims());
// ElementwiseComputeEx can do broadcasting
ElementwiseComputeEx<SubFunctor<T>, DeviceContext, T>(
ctx, &slice_tensor, value_tensor, -1, SubFunctor<T>(), &slice_tensor);
} else {
Tensor value_t(framework::TransToPhiDataType(dtype));
auto value_dims = phi::make_ddim(shape);
CheckIsDimsMatch(slice_dims_for_assign, value_dims);
value_t.mutable_data<T>(value_dims, place);
auto value_name = GetValueName(dtype);
CopyVectorToTensor<T>(value_name.c_str(), &value_t, ctx);
value_t.Resize(value_dims);
ElementwiseComputeEx<SubFunctor<T>, DeviceContext, T>(
ctx, &slice_tensor, &value_t, -1, SubFunctor<T>(), &slice_tensor);
}
slice_tensor.Resize(slice_dims);
// - Step 2.2 Pad slice tensor with 0
pad_e.device(eigen_place) = pad_e.constant(T(0));
pad_e.stridedSlice(starts_indices, ends_indices, strides_indices)
.device(eigen_place) = slice_e;
// Step 3: Set out tensor with value_tensor
out_e.device(eigen_place) = out_e - pad_e;
}
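As a reading aid, here is a rough NumPy sketch of what SetValueCompute above achieves for a plain basic-slice index (the real code additionally handles decrease/none axes, strides, and broadcasting); the helper name and the scalar value are illustrative only:

import numpy as np

def set_value_sketch(x, value, index):
    out = x.copy()                  # TensorCopy(in -> out)
    out[index] = 0                  # Step 1: zero the slice of out
    pad = np.zeros_like(x)
    pad[index] = 0 - value          # Step 2: slice = 0 - value (SubFunctor), scattered into pad
    return out - pad                # Step 3: out - pad == value inside the slice, x elsewhere

x = np.arange(12, dtype=np.float32).reshape(3, 4)
print(set_value_sketch(x, 7.0, np.s_[:, 1:3]))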
template <typename DeviceContext, typename T>
void SetValueCompute_dispatch(const framework::ExecutionContext& ctx,
framework::Tensor* in,
framework::Tensor* value_tensor,
framework::Tensor* out,
const std::vector<int64_t>& axes,
std::vector<int64_t>* starts,
std::vector<int64_t>* ends,
const std::vector<int64_t>& shape,
int rank) {
switch (rank) {
case 1:
SetValueCompute<DeviceContext, T, 1>(
ctx, in, value_tensor, out, axes, starts, ends, shape);
break;
case 2:
SetValueCompute<DeviceContext, T, 2>(
ctx, in, value_tensor, out, axes, starts, ends, shape);
break;
case 3:
SetValueCompute<DeviceContext, T, 3>(
ctx, in, value_tensor, out, axes, starts, ends, shape);
break;
case 4:
SetValueCompute<DeviceContext, T, 4>(
ctx, in, value_tensor, out, axes, starts, ends, shape);
break;
case 5:
SetValueCompute<DeviceContext, T, 5>(
ctx, in, value_tensor, out, axes, starts, ends, shape);
break;
case 6:
SetValueCompute<DeviceContext, T, 6>(
ctx, in, value_tensor, out, axes, starts, ends, shape);
break;
default:
PADDLE_THROW(platform::errors::InvalidArgument(
"The rank of input should be less than 7, but received %d.", rank));
}
}
template <typename DeviceContext, typename T>
void Tensor_Conj(const DeviceContext& dev_ctx,
const framework::Tensor& tensor,
framework::Tensor* out) {
out->Resize(tensor.dims());
platform::ForRange<DeviceContext> out_for_range(dev_ctx, tensor.numel());
phi::funcs::ConjFunctor<T> out_functor(
tensor.data<T>(),
tensor.numel(),
out->mutable_data<T>(dev_ctx.GetPlace()));
out_for_range(out_functor);
}
template <typename DeviceContext, typename T>
void Tensor_Add(const DeviceContext& dev_ctx,
const framework::Tensor& src1,
const framework::Tensor& src2,
framework::Tensor* out) {
out->Resize(src1.dims());
out->mutable_data<T>(dev_ctx.GetPlace());
phi::AddRawKernel<
T,
typename paddle::framework::ConvertToPhiContext<DeviceContext>::TYPE>(
static_cast<const typename paddle::framework::ConvertToPhiContext<
DeviceContext>::TYPE&>(dev_ctx),
src1,
src2,
-1,
out);
}
template <typename DeviceContext, typename T>
void Tensor_Sub(const DeviceContext& dev_ctx,
const framework::Tensor& src1,
const framework::Tensor& src2,
framework::Tensor* out) {
out->Resize(src1.dims());
out->mutable_data<T>(dev_ctx.GetPlace());
phi::SubtractRawKernel<
T,
typename paddle::framework::ConvertToPhiContext<DeviceContext>::TYPE>(
static_cast<const typename paddle::framework::ConvertToPhiContext<
DeviceContext>::TYPE&>(dev_ctx),
src1,
src2,
-1,
out);
}
template <typename DeviceContext, typename T, size_t D>
void SliceCompute(const framework::ExecutionContext& ctx,
const framework::Tensor* in,
framework::Tensor* out,
const std::vector<int>& axes_int,
const std::vector<int>& starts_int,
const std::vector<int>& ends_int) {
std::vector<int64_t> axes(axes_int.begin(), axes_int.end());
std::vector<int64_t> starts(starts_int.begin(), starts_int.end());
std::vector<int64_t> ends(ends_int.begin(), ends_int.end());
std::vector<int> decrease_axis = {};
std::vector<int> infer_flags = {};
PADDLE_ENFORCE_EQ(
starts.size(),
axes.size(),
platform::errors::InvalidArgument(
"The size of starts must be equal to the size of axes."));
PADDLE_ENFORCE_EQ(ends.size(),
axes.size(),
platform::errors::InvalidArgument(
"The size of ends must be equal to the size of axes."));
// Step 2: Compute output
auto in_dims = in->dims();
auto out_dims = out->dims();
auto slice_dims = out_dims;
// 2.1 Infer output dims
for (size_t i = 0; i < axes.size(); ++i) {
// when start == -1 && end == start+1
if (starts[i] == -1 && ends[i] == 0 && infer_flags[i] == -1) {
auto ret = std::find(decrease_axis.begin(), decrease_axis.end(), axes[i]);
if (ret != decrease_axis.end()) {
ends[i] = in_dims[axes[i]];
}
}
}
phi::funcs::CheckAndUpdateSliceAttrs(in_dims, axes, &starts, &ends);
slice_dims = phi::funcs::GetSliceDims<int64_t>(
in_dims, axes, starts, ends, nullptr, nullptr);
out_dims = phi::funcs::GetDecreasedDims(slice_dims, decrease_axis);
// 2.2 Get output
auto offsets = Eigen::DSizes<Eigen::DenseIndex, D>();
auto extents = Eigen::DSizes<Eigen::DenseIndex, D>();
for (size_t i = 0; i < D; ++i) {
offsets[i] = 0;
extents[i] = slice_dims[i];
}
for (size_t i = 0; i < axes.size(); ++i) {
offsets[axes[i]] = starts[i];
}
out->Resize(slice_dims);
out->mutable_data<T>(ctx.GetPlace());
auto in_t = framework::EigenTensor<T, D>::From(*in, in_dims);
auto out_t = framework::EigenTensor<T, D>::From(*out, slice_dims);
auto& eigen_place =
*ctx.template device_context<DeviceContext>().eigen_device();
if (in->numel() <= Eigen::NumTraits<int>::highest()) {
// similar to tf.slice:
// if the element count is less than INT_MAX, switch the index type to int
Eigen::DSizes<int, D> offsets_32bit, extents_32bit;
for (size_t i = 0; i < D; i++) {
offsets_32bit[i] = offsets[i];
extents_32bit[i] = extents[i];
}
EigenSlice<std::decay_t<decltype(eigen_place)>, T, D>::Eval(
eigen_place,
framework::To32BitIndex(out_t),
framework::To32BitIndex(in_t),
offsets_32bit,
extents_32bit);
} else {
EigenSlice<std::decay_t<decltype(eigen_place)>, T, D>::Eval(
eigen_place, out_t, in_t, offsets, extents);
}
out->Resize(out_dims);
out->mutable_data<T>(ctx.GetPlace());
}
template <typename DeviceContext, typename T>
void Tensor_narrow(const framework::ExecutionContext& ctx,
const framework::Tensor* src,
framework::Tensor* out,
int row_s,
int row_e,
int col_s,
int col_e) {
auto rank = src->dims().size();
std::vector<int> axes_int = {rank - 2, rank - 1};
std::vector<int> starts_int = {row_s, col_s};
std::vector<int> ends_int = {row_e, col_e};
switch (rank) {
case 1:
SliceCompute<DeviceContext, T, 1>(
ctx, src, out, axes_int, starts_int, ends_int);
break;
case 2:
SliceCompute<DeviceContext, T, 2>(
ctx, src, out, axes_int, starts_int, ends_int);
break;
case 3:
SliceCompute<DeviceContext, T, 3>(
ctx, src, out, axes_int, starts_int, ends_int);
break;
case 4:
SliceCompute<DeviceContext, T, 4>(
ctx, src, out, axes_int, starts_int, ends_int);
break;
case 5:
SliceCompute<DeviceContext, T, 5>(
ctx, src, out, axes_int, starts_int, ends_int);
break;
case 6:
SliceCompute<DeviceContext, T, 6>(
ctx, src, out, axes_int, starts_int, ends_int);
break;
default:
PADDLE_THROW(platform::errors::InvalidArgument(
"The rank of input should be less than 7, but received %d.", rank));
}
}
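Tensor_narrow simply slices the trailing two dimensions; an equivalent NumPy sketch (illustrative helper name) would be:

import numpy as np

def tensor_narrow_sketch(src, row_s, row_e, col_s, col_e):
    # Slice rows [row_s, row_e) and columns [col_s, col_e) of the last two dims.
    return src[..., row_s:row_e, col_s:col_e].copy()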
template <typename DeviceContext>
void arange(const DeviceContext& dev_ctx,
framework::Tensor* tmp,
int w,
int batchsize = 1,
int h = 1) {
tmp->Resize(phi::make_ddim({batchsize * w}));
platform::CPUPlace cpu;
auto tmpdata = tmp->mutable_data<int32_t>(cpu);
for (int b = 0; b < batchsize; b++) {
for (int i = 0; i < w; i++) {
tmpdata[b * w + i] = static_cast<int32_t>(b * h + i);
}
}
}
template <typename T>
struct OneFunctor {
OneFunctor(T* output, int* idtptr, int w, int dim)
: output_(output), idtptr_(idtptr), w_(w), dim_(dim) {}
HOSTDEVICE void operator()(size_t idx) const {
output_[w_ * idtptr_[idx] + idx % dim_] = static_cast<T>(1);
}
T* output_;
int* idtptr_;
int w_;
int dim_;
};
template <typename DeviceContext, typename T>
void LU_Unpack(const DeviceContext& dev_ctx,
const framework::Tensor* LU,
framework::Tensor* L,
framework::Tensor* U) {
const auto udims = LU->dims();
L->Resize(udims);
U->Resize(udims);
const auto H = udims[udims.size() - 2];
const auto W = udims[udims.size() - 1];
auto L_dataptr = L->mutable_data<T>(dev_ctx.GetPlace());
platform::ForRange<DeviceContext> x_for_range(dev_ctx, LU->numel());
phi::funcs::TrilTriuCompute<T> tril_computer(
LU->data<T>(), -1, true, H, W, L_dataptr);
x_for_range(tril_computer);
phi::funcs::TrilTriuCompute<T> triu_computer(
LU->data<T>(), 0, false, H, W, U->mutable_data<T>(dev_ctx.GetPlace()));
x_for_range(triu_computer);
// set L's diagonal to 1
auto dim = std::min(H, W);
framework::Tensor rowtensor, rt_dev;
auto batchsize = product(phi::slice_ddim(udims, 0, udims.size() - 2));
batchsize = std::max(static_cast<int>(batchsize), 1);
arange<DeviceContext>(dev_ctx, &rowtensor, dim, batchsize, H);
auto idtptr = rowtensor.data<int32_t>();
if (platform::is_gpu_place(dev_ctx.GetPlace())) {
framework::TensorCopy(rowtensor, dev_ctx.GetPlace(), &rt_dev);
idtptr = rt_dev.data<int32_t>();
}
platform::ForRange<DeviceContext> for_range(dev_ctx, rowtensor.numel());
OneFunctor<T> functor(L_dataptr, idtptr, W, dim);
for_range(functor);
}
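The effect of LU_Unpack can be summarized with a small NumPy sketch (assuming a single, unbatched LU factor; the function name is illustrative):

import numpy as np

def lu_unpack_lu_sketch(lu):
    H, W = lu.shape
    L = np.tril(lu, k=-1)            # strictly lower triangle of the packed factor
    U = np.triu(lu, k=0)             # upper triangle including the diagonal
    d = np.arange(min(H, W))
    L[d, d] = 1                      # unit diagonal of L, as set by OneFunctor
    return L, U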
template <typename DeviceContext, typename T>
void scatterpivot(const DeviceContext& dev_ctx,
T* out_data,
framework::Tensor* idlst,
int w,
int dim) {
framework::Tensor idlst_tmp;
idlst_tmp.Resize(idlst->dims());
idlst_tmp.mutable_data<int32_t>(dev_ctx.GetPlace());
framework::TensorCopy(*idlst, dev_ctx.GetPlace(), &idlst_tmp);
auto idtptr = idlst_tmp.data<int32_t>();
platform::ForRange<DeviceContext> for_range(dev_ctx, idlst_tmp.numel());
OneFunctor<T> functor(out_data, idtptr, w, dim);
for_range(functor);
}
template <typename DeviceContext, typename T>
void Unpack_Pivot(const DeviceContext& dev_ctx,
const framework::Tensor& Pivot,
framework::Tensor* P,
int h,
int w) {
auto dims = Pivot.dims();
auto Pdimvec = vectorize(dims);
auto prank = Pdimvec.size();
auto Pnum = dims[prank - 1];
framework::Tensor Pivot_cpu;
platform::CPUPlace cpu;
framework::TensorCopy(Pivot, cpu, &Pivot_cpu);
auto pdataptr = Pivot_cpu.data<int32_t>();
Pdimvec[prank - 1] = h;
Pdimvec.emplace_back(h);
auto Pdim = phi::make_ddim(Pdimvec);
P->Resize(Pdim);
auto pdata = P->mutable_data<T>(dev_ctx.GetPlace());
phi::funcs::SetConstant<DeviceContext, T> setter;
setter(dev_ctx, P, static_cast<T>(0));
auto batchsize = product(phi::slice_ddim(dims, 0, prank - 1));
batchsize = std::max(static_cast<int>(batchsize), 1);
framework::Tensor idt;
for (int i = 0; i < batchsize; i++) {
arange<DeviceContext>(dev_ctx, &idt, h);
auto idlst = idt.data<int32_t>();
for (int j = 0; j < Pnum; j++) {
if (idlst[pdataptr[i * Pnum + j] - 1] == idlst[j]) continue;
auto temp = idlst[j];
idlst[j] = idlst[pdataptr[i * Pnum + j] - 1];
idlst[pdataptr[i * Pnum + j] - 1] = temp;
}
scatterpivot(dev_ctx, &(pdata[i * h * h]), &idt, h, h);
}
}
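Unpack_Pivot converts the 1-based LAPACK row-swap sequence into a permutation matrix; a per-batch NumPy sketch of the same swap loop (illustrative names) is:

import numpy as np

def unpack_pivots_sketch(pivots, h):
    # pivots: 1-based row swaps produced by the LU factorization (one batch).
    perm = np.arange(h)
    for j, p in enumerate(pivots):
        perm[j], perm[p - 1] = perm[p - 1], perm[j]   # mirror the idlst swap above
    P = np.zeros((h, h), dtype=np.float32)
    P[perm, np.arange(h)] = 1.0                       # OneFunctor: P[perm[j], j] = 1
    return P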
} // namespace operators
} // namespace paddle
@@ -12,7 +12,11 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/lu_unpack_op.h"
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/phi/infermeta/backward.h"
#include "paddle/phi/infermeta/binary.h"
namespace paddle {
namespace operators {
@@ -42,44 +46,6 @@ class LU_UnpackOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *context) const override {
OP_INOUT_CHECK(context->HasInput("X"), "Input", "X", "LU_Unpack");
OP_INOUT_CHECK(context->HasInput("Pivots"), "Input", "Pivots", "LU_Unpack");
OP_INOUT_CHECK(context->HasOutput("L"), "Output", "L", "LU_Unpack");
OP_INOUT_CHECK(context->HasOutput("U"), "Output", "U", "LU_Unpack");
OP_INOUT_CHECK(context->HasOutput("Pmat"), "Output", "Pmat", "LU_Unpack");
bool unpack_ludata = context->Attrs().Get<bool>("unpack_ludata");
bool unpack_pivots = context->Attrs().Get<bool>("unpack_pivots");
auto x_dims = context->GetInputDim("X");
int x_rank = x_dims.size();
PADDLE_ENFORCE_GE(x_rank,
2,
platform::errors::InvalidArgument(
"the rank of input must greater than 2"));
// context->SetOutputDim("Out", x_dims);
int m = x_dims[x_rank - 1];
int n = x_dims[x_rank - 2];
int min_mn = std::min(m, n);
if (unpack_ludata) {
auto ldims = x_dims;
auto udims = x_dims;
if (m >= n) {
udims[x_rank - 2] = min_mn;
} else {
ldims[x_rank - 1] = min_mn;
}
context->SetOutputDim("U", udims);
context->SetOutputDim("L", ldims);
}
if (unpack_pivots) {
auto pdims = x_dims;
pdims[x_rank - 1] = m;
context->SetOutputDim("Pmat", pdims);
}
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
@@ -143,25 +109,6 @@ class LU_UnpackGradOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *ctx) const override {
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "lu_unpack");
OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("L")),
"Input",
"L@GRAD",
"lu_unpack");
OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("U")),
"Input",
"U@GRAD",
"lu_unpack");
auto x_dims = ctx->GetInputDim("X");
auto x_grad_name = framework::GradVarName("X");
if (ctx->HasOutput(x_grad_name)) {
ctx->SetOutputDim(x_grad_name, x_dims);
}
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
@@ -175,19 +122,21 @@ class LU_UnpackGradOp : public framework::OperatorWithKernel {
namespace ops = paddle::operators;
namespace plat = paddle::platform;
DECLARE_INFER_SHAPE_FUNCTOR(lu_unpack,
LUUnpackInferMetaFunctor,
PD_INFER_META(phi::LUUnpackInferMeta));
DECLARE_INFER_SHAPE_FUNCTOR(lu_unpack_grad,
LUUnpackGradInferMetaFunctor,
PD_INFER_META(phi::LUUnpackGradInferMeta));
REGISTER_OPERATOR(lu_unpack,
ops::LU_UnpackOp,
ops::LU_UnpackOpMaker,
ops::LU_UnpackOpVarTypeInference,
ops::LU_UnpackOpGradMaker<paddle::framework::OpDesc>,
ops::LU_UnpackOpGradMaker<paddle::imperative::OpBase>);
ops::LU_UnpackOpGradMaker<paddle::imperative::OpBase>,
LUUnpackInferMetaFunctor);
REGISTER_OPERATOR(lu_unpack_grad,
ops::LU_UnpackGradOp,
ops::LU_UnpackGradOpVarTypeInference);
REGISTER_OP_CPU_KERNEL(lu_unpack,
ops::LU_UnpackKernel<phi::CPUContext, float>,
ops::LU_UnpackKernel<phi::CPUContext, double>);
REGISTER_OP_CPU_KERNEL(lu_unpack_grad,
ops::LU_UnpackGradKernel<phi::CPUContext, float>,
ops::LU_UnpackGradKernel<phi::CPUContext, double>);
ops::LU_UnpackGradOpVarTypeInference,
LUUnpackGradInferMetaFunctor);
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/lu_unpack_op.h"
#include "paddle/fluid/memory/memory.h"
namespace paddle {
namespace operators {} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OP_CUDA_KERNEL(lu_unpack,
ops::LU_UnpackKernel<plat::CUDADeviceContext, float>,
ops::LU_UnpackKernel<plat::CUDADeviceContext, double>);
REGISTER_OP_CUDA_KERNEL(
lu_unpack_grad,
ops::LU_UnpackGradKernel<plat::CUDADeviceContext, float>,
ops::LU_UnpackGradKernel<plat::CUDADeviceContext, double>);
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/lu_op.h"
#include "paddle/fluid/platform/for_range.h"
#include "paddle/phi/kernels/funcs/tril_triu_compute.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
using LoDTensorArray = framework::LoDTensorArray;
template <typename DeviceContext, typename T>
class LU_UnpackKernel : public framework::OpKernel<T> {
public:
void Compute(const paddle::framework::ExecutionContext& ctx) const override {
auto xin = ctx.Input<framework::Tensor>("X");
auto P = ctx.Input<framework::Tensor>("Pivots");
auto ltensor = ctx.Output<framework::Tensor>("L");
auto utensor = ctx.Output<framework::Tensor>("U");
auto ptensor = ctx.Output<framework::Tensor>("Pmat");
auto unpack_ludata = ctx.Attr<bool>("unpack_ludata");
auto unpack_pivots = ctx.Attr<bool>("unpack_pivots");
const auto& dev_ctx = ctx.template device_context<DeviceContext>();
auto xdims = xin->dims();
int xrank = xdims.size();
int64_t m = xdims[xrank - 2];
int64_t n = xdims[xrank - 1];
int64_t k = std::min(m, n);
if (unpack_ludata) {
ltensor->mutable_data<T>(ctx.GetPlace());
utensor->mutable_data<T>(ctx.GetPlace());
framework::Tensor L, U;
LU_Unpack<DeviceContext, T>(dev_ctx, xin, &L, &U);
if (m >= n) {
framework::TensorCopy(L, ctx.GetPlace(), ltensor);
Tensor_narrow<DeviceContext, T>(ctx, &U, utensor, 0, k, 0, k);
} else {
framework::TensorCopy(U, ctx.GetPlace(), utensor);
Tensor_narrow<DeviceContext, T>(ctx, &L, ltensor, 0, k, 0, k);
}
}
if (unpack_pivots) {
ptensor->mutable_data<T>(ctx.GetPlace());
Unpack_Pivot<DeviceContext, T>(dev_ctx, *P, ptensor, m, k);
}
}
};
template <typename DeviceContext, typename T>
class LU_UnpackGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto dl = ctx.Input<framework::Tensor>(framework::GradVarName("L"));
auto du = ctx.Input<framework::Tensor>(framework::GradVarName("U"));
auto dx = ctx.Output<framework::Tensor>(framework::GradVarName("X"));
dx->mutable_data<T>(ctx.GetPlace());
const auto& dev_ctx = ctx.template device_context<DeviceContext>();
framework::Tensor dl_tril, du_triu;
const auto ldims = dl->dims();
dl_tril.Resize(ldims);
auto H = ldims[ldims.size() - 2];
auto W = ldims[ldims.size() - 1];
auto L_dataptr = dl_tril.mutable_data<T>(dev_ctx.GetPlace());
platform::ForRange<DeviceContext> l_for_range(dev_ctx, dl->numel());
phi::funcs::TrilTriuCompute<T> tril_computer(
dl->data<T>(), -1, true, H, W, L_dataptr);
l_for_range(tril_computer);
const auto udims = du->dims();
du_triu.Resize(udims);
H = udims[udims.size() - 2];
W = udims[udims.size() - 1];
auto U_dataptr = du_triu.mutable_data<T>(dev_ctx.GetPlace());
platform::ForRange<DeviceContext> u_for_range(dev_ctx, du->numel());
phi::funcs::TrilTriuCompute<T> triu_computer(
du->data<T>(), 0, false, H, W, U_dataptr);
u_for_range(triu_computer);
auto xdims = dx->dims();
int xrank = xdims.size();
int64_t m = xdims[xrank - 2];
int64_t n = xdims[xrank - 1];
int64_t k = std::min(m, n);
std::vector<int64_t> axes = {xrank - 2, xrank - 1};
std::vector<int64_t> slice_starts(2, 0);
std::vector<int64_t> slice_ends(2, 0);
auto valuedims = vectorize(xdims);
phi::funcs::SetConstant<DeviceContext, T> setter;
setter(dev_ctx, dx, static_cast<T>(0));
if (m <= n) {
slice_starts[0] = 0;
slice_starts[1] = 0;
slice_ends[0] = k;
slice_ends[1] = k;
valuedims[xrank - 2] = k;
valuedims[xrank - 1] = k;
SetValueCompute_dispatch<DeviceContext, T>(ctx,
dx,
&dl_tril,
dx,
axes,
&slice_starts,
&slice_ends,
valuedims,
xrank);
Tensor_Add<DeviceContext, T>(dev_ctx, *dx, du_triu, dx);
} else {
slice_starts[0] = 0;
slice_starts[1] = 0;
slice_ends[0] = k;
slice_ends[1] = k;
valuedims[xrank - 2] = k;
valuedims[xrank - 1] = k;
SetValueCompute_dispatch<DeviceContext, T>(ctx,
dx,
&du_triu,
dx,
axes,
&slice_starts,
&slice_ends,
valuedims,
xrank);
Tensor_Add<DeviceContext, T>(dev_ctx, *dx, dl_tril, dx);
}
}
};
} // namespace operators
} // namespace paddle
@@ -1443,6 +1443,16 @@
func : lu
backward : lu_grad
- api : lu_unpack
args : (Tensor x, Tensor pivots, bool unpack_ludata, bool unpack_pivots)
output : Tensor(pmat), Tensor(l), Tensor(u)
infer_meta :
func : LUUnpackInferMeta
kernel :
func : lu_unpack
data_type : x
backward : lu_unpack_grad
# masked_select
- api : masked_select
args : (Tensor x, Tensor mask)
......
@@ -1254,6 +1254,15 @@
kernel :
func : lu_grad
- backward_api : lu_unpack_grad
forward : lu_unpack (Tensor x, Tensor pivots, bool unpack_ludata, bool unpack_pivots) -> Tensor(pmat), Tensor(l), Tensor(u)
args : (Tensor x, Tensor pivots, Tensor l, Tensor u, Tensor pmat, Tensor l_grad, Tensor u_grad, bool unpack_ludata, bool unpack_pivots)
output : Tensor(x_grad)
infer_meta :
func : LUUnpackGradInferMeta
kernel :
func : lu_unpack_grad
- backward_api : masked_select_grad
forward : masked_select (Tensor x, Tensor mask) -> Tensor(out)
args : (Tensor x, Tensor mask, Tensor out_grad)
......
@@ -456,6 +456,24 @@ void LUGradInferMeta(const MetaTensor& x,
}
}
void LUUnpackGradInferMeta(const MetaTensor& x,
const MetaTensor& pivots,
const MetaTensor& l,
const MetaTensor& u,
const MetaTensor& pmat,
const MetaTensor& l_grad,
const MetaTensor& u_grad,
bool unpack_ludata,
bool unpack_pivots,
MetaTensor* x_grad) {
auto x_dims = x.dims();
if (x_grad) {
x_grad->set_dims(x_dims);
x_grad->set_dtype(x.dtype());
}
}
void MaxPoolWithIndexGradInferMeta(const MetaTensor& x,
const MetaTensor& mask,
const MetaTensor& dout,
......
@@ -207,6 +207,17 @@ void LUGradInferMeta(const MetaTensor& x,
bool pivot,
MetaTensor* x_grad);
void LUUnpackGradInferMeta(const MetaTensor& x,
const MetaTensor& pivots,
const MetaTensor& l,
const MetaTensor& u,
const MetaTensor& pmat,
const MetaTensor& l_grad,
const MetaTensor& u_grad,
bool unpack_ludata,
bool unpack_pivots,
MetaTensor* x_grad);
void MaxPoolWithIndexGradInferMeta(const MetaTensor& x,
const MetaTensor& mask,
const MetaTensor& dout,
......
@@ -1486,6 +1486,52 @@ void LogLossInferMeta(const MetaTensor& input,
out->share_lod(input);
}
void LUUnpackInferMeta(const MetaTensor& x,
const MetaTensor& pivots,
bool unpack_ludata,
bool unpack_pivots,
MetaTensor* pmat,
MetaTensor* l,
MetaTensor* u) {
PADDLE_ENFORCE_NOT_NULL(
pmat,
phi::errors::InvalidArgument("Output(Pmat) should not be nullptr."));
PADDLE_ENFORCE_NOT_NULL(
l, phi::errors::InvalidArgument("Output(L) should not be nullptr."));
PADDLE_ENFORCE_NOT_NULL(
u, phi::errors::InvalidArgument("Output(U) should not be nullptr."));
auto x_dims = x.dims();
int x_rank = x_dims.size();
PADDLE_ENFORCE_GE(
x_rank,
2,
phi::errors::InvalidArgument("The rank of input must greater than 2."));
int m = x_dims[x_rank - 1];
int n = x_dims[x_rank - 2];
int min_mn = std::min(m, n);
if (unpack_ludata) {
auto ldims = x_dims;
auto udims = x_dims;
if (m >= n) {
udims[x_rank - 2] = min_mn;
} else {
ldims[x_rank - 1] = min_mn;
}
u->set_dims(udims);
u->set_dtype(x.dtype());
l->set_dims(ldims);
l->set_dtype(x.dtype());
}
if (unpack_pivots) {
auto pdims = x_dims;
pdims[x_rank - 1] = m;
pmat->set_dims(pdims);
pmat->set_dtype(x.dtype());
}
}
void MaskedSelectInferMeta(const MetaTensor& x,
const MetaTensor& mask,
MetaTensor* out) {
......
@@ -225,6 +225,14 @@ void LogLossInferMeta(const MetaTensor& input,
MetaTensor* out,
MetaConfig config = MetaConfig());
void LUUnpackInferMeta(const MetaTensor& x,
const MetaTensor& pivots,
bool unpack_ludata,
bool unpack_pivots,
MetaTensor* pmat,
MetaTensor* l,
MetaTensor* u);
void MaskedSelectInferMeta(const MetaTensor& x,
const MetaTensor& mask,
MetaTensor* out);
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/lu_unpack_grad_kernel_impl.h"
#include "paddle/phi/kernels/lu_unpack_grad_kernel.h"
PD_REGISTER_KERNEL(
lu_unpack_grad, CPU, ALL_LAYOUT, phi::LUUnpackGradKernel, float, double) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/lu_unpack_kernel_impl.h"
#include "paddle/phi/kernels/lu_unpack_kernel.h"
PD_REGISTER_KERNEL(
lu_unpack, CPU, ALL_LAYOUT, phi::LUUnpackKernel, float, double) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/lu_unpack_grad_kernel_impl.h"
#include "paddle/phi/kernels/lu_unpack_grad_kernel.h"
PD_REGISTER_KERNEL(
lu_unpack_grad, GPU, ALL_LAYOUT, phi::LUUnpackGradKernel, float, double) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/lu_unpack_kernel_impl.h"
#include "paddle/phi/kernels/lu_unpack_kernel.h"
PD_REGISTER_KERNEL(
lu_unpack, GPU, ALL_LAYOUT, phi::LUUnpackKernel, float, double) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/kernels/impl/lu_kernel_impl.h"
namespace phi {
template <typename T, typename Context>
void LUUnpackGradKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& pivots,
const DenseTensor& l,
const DenseTensor& u,
const DenseTensor& pmat,
const DenseTensor& l_grad,
const DenseTensor& u_grad,
bool unpack_ludata,
bool unpack_pivots,
DenseTensor* x_grad) {
dev_ctx.template Alloc<T>(x_grad);
DenseTensor dl_tril, du_triu;
const auto ldims = l_grad.dims();
dl_tril.Resize(ldims);
auto H = ldims[ldims.size() - 2];
auto W = ldims[ldims.size() - 1];
dev_ctx.template Alloc<T>(&dl_tril);
auto L_dataptr = dl_tril.data<T>();
phi::funcs::ForRange<Context> l_for_range(dev_ctx, l_grad.numel());
phi::funcs::TrilTriuCompute<T> tril_computer(
l_grad.data<T>(), -1, true, H, W, L_dataptr);
l_for_range(tril_computer);
const auto udims = u_grad.dims();
du_triu.Resize(udims);
H = udims[udims.size() - 2];
W = udims[udims.size() - 1];
dev_ctx.template Alloc<T>(&du_triu);
auto U_dataptr = du_triu.data<T>();
phi::funcs::ForRange<Context> u_for_range(dev_ctx, u_grad.numel());
phi::funcs::TrilTriuCompute<T> triu_computer(
u_grad.data<T>(), 0, false, H, W, U_dataptr);
u_for_range(triu_computer);
auto xdims = x_grad->dims();
int xrank = xdims.size();
int64_t m = xdims[xrank - 2];
int64_t n = xdims[xrank - 1];
int64_t k = std::min(m, n);
std::vector<int64_t> axes = {xrank - 2, xrank - 1};
std::vector<int64_t> slice_starts(2, 0);
std::vector<int64_t> slice_ends(2, 0);
auto valuedims = vectorize(xdims);
phi::funcs::SetConstant<Context, T> setter;
setter(dev_ctx, x_grad, static_cast<T>(0));
if (m <= n) {
slice_starts[0] = 0;
slice_starts[1] = 0;
slice_ends[0] = k;
slice_ends[1] = k;
valuedims[xrank - 2] = k;
valuedims[xrank - 1] = k;
SetValueCompute_dispatch<Context, T>(dev_ctx,
x_grad,
&dl_tril,
x_grad,
axes,
&slice_starts,
&slice_ends,
valuedims,
xrank);
Tensor_Add<Context, T>(dev_ctx, *x_grad, du_triu, x_grad);
} else {
slice_starts[0] = 0;
slice_starts[1] = 0;
slice_ends[0] = k;
slice_ends[1] = k;
valuedims[xrank - 2] = k;
valuedims[xrank - 1] = k;
SetValueCompute_dispatch<Context, T>(dev_ctx,
x_grad,
&du_triu,
x_grad,
axes,
&slice_starts,
&slice_ends,
valuedims,
xrank);
Tensor_Add<Context, T>(dev_ctx, *x_grad, dl_tril, x_grad);
}
}
} // namespace phi
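For the square case (m == n), the gradient computed above reduces to scattering the two triangles back into the packed factor; a minimal NumPy reference (illustrative name) is:

import numpy as np

def lu_unpack_grad_square_sketch(l_grad, u_grad):
    # x_grad gets the strictly lower triangle from dL and the upper triangle
    # (including the diagonal) from dU; the unit diagonal of L carries no gradient.
    return np.tril(l_grad, k=-1) + np.triu(u_grad, k=0)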
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/kernels/impl/lu_kernel_impl.h"
namespace phi {
template <typename T, typename Context>
void LUUnpackKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& pivots,
bool unpack_ludata,
bool unpack_pivots,
DenseTensor* pmat,
DenseTensor* l,
DenseTensor* u) {
auto xdims = x.dims();
int xrank = xdims.size();
int64_t m = xdims[xrank - 2];
int64_t n = xdims[xrank - 1];
int64_t k = std::min(m, n);
if (unpack_ludata) {
dev_ctx.template Alloc<T>(l);
dev_ctx.template Alloc<T>(u);
DenseTensor L, U;
LU_Unpack<Context, T>(dev_ctx, &x, &L, &U);
if (m >= n) {
phi::Copy(dev_ctx, L, dev_ctx.GetPlace(), false, l);
Tensor_narrow<Context, T>(dev_ctx, &U, u, 0, k, 0, k);
} else {
phi::Copy(dev_ctx, U, dev_ctx.GetPlace(), false, u);
Tensor_narrow<Context, T>(dev_ctx, &L, l, 0, k, 0, k);
}
}
if (unpack_pivots) {
dev_ctx.template Alloc<T>(pmat);
Unpack_Pivot<Context, T>(dev_ctx, pivots, pmat, m, k);
}
}
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/core/dense_tensor.h"
namespace phi {
template <typename T, typename Context>
void LUUnpackGradKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& pivots,
const DenseTensor& l,
const DenseTensor& u,
const DenseTensor& pmat,
const DenseTensor& l_grad,
const DenseTensor& u_grad,
bool unpack_ludata,
bool unpack_pivots,
DenseTensor* x_grad);
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/core/dense_tensor.h"
namespace phi {
template <typename T, typename Context>
void LUUnpackKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& pivots,
bool unpack_ludata,
bool unpack_pivots,
DenseTensor* pmat,
DenseTensor* l,
DenseTensor* u);
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/core/compat/op_utils.h"
namespace phi {
KernelSignature LUUnpackOpArgumentMapping(const ArgumentMappingContext& ctx) {
return KernelSignature("lu_unpack",
{"X", "Pivots"},
{"unpack_ludata", "unpack_pivots"},
{"Pmat", "L", "U"});
}
KernelSignature LUUnpackGradOpArgumentMapping(
const ArgumentMappingContext& ctx) {
return KernelSignature("lu_unpack_grad",
{"X", "Pivots", "L", "U", "Pmat", "L@GRAD", "U@GRAD"},
{"unpack_ludata", "unpack_pivots"},
{"X@GRAD"});
}
} // namespace phi
PD_REGISTER_ARG_MAPPING_FN(lu_unpack, phi::LUUnpackOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(lu_unpack_grad, phi::LUUnpackGradOpArgumentMapping);
@@ -120,6 +120,8 @@ class TestLU_UnpackOp(OpTest):
def setUp(self):
self.op_type = "lu_unpack"
self.python_api = paddle.tensor.linalg.lu_unpack
self.python_out_sig = ["Pmat", "L", "U"]
self.config()
x = np.random.random(self.x_shape).astype(self.dtype)
if paddle.in_dynamic_mode():
@@ -156,10 +158,10 @@
}
def test_check_output(self):
self.check_output()
self.check_output(check_eager=True)
def test_check_grad(self):
self.check_grad(['X'], ['L', 'U'])
self.check_grad(['X'], ['L', 'U'], check_eager=True)
# m = n
......
@@ -2200,6 +2200,11 @@ def lu_unpack(x, y, unpack_ludata=True, unpack_pivots=True, name=None):
# one can verify : X = P @ L @ U ;
"""
if in_dygraph_mode():
P, L, U = _C_ops.final_state_lu_unpack(x, y, unpack_ludata,
unpack_pivots)
return P, L, U
if paddle.in_dynamic_mode():
P, L, U = _C_ops.lu_unpack(x, y, 'unpack_ludata', unpack_ludata,
'unpack_pivots', unpack_pivots)
......
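Putting the pieces together, a minimal end-to-end sketch of the user-facing API exercised by this change (assuming paddle.linalg.lu is available, as suggested by the docstring above) is:

import paddle

x = paddle.to_tensor([[2.0, 1.0], [0.5, 3.0]])
lu, pivots = paddle.linalg.lu(x)                 # packed LU factor and 1-based pivots
P, L, U = paddle.linalg.lu_unpack(lu, pivots)    # dispatches to the new phi lu_unpack kernel in eager mode
print(paddle.allclose(P @ L @ U, x))             # one can verify: X = P @ L @ U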