Commit a1275c8b (unverified)
Repository: BaiXuePrincess/Paddle, forked from PaddlePaddle/Paddle
Authored on Dec 31, 2021 by zhiboniu; committed by GitHub on Dec 31, 2021.
add lu_op backward (#38616)
Parent: 8d32cef8

Showing 4 changed files, with 296 additions and 0 deletions (+296, -0):
- paddle/fluid/operators/lu_op.cc (+67, -0)
- paddle/fluid/operators/lu_op.cu (+3, -0)
- paddle/fluid/operators/lu_op.h (+223, -0)
- python/paddle/fluid/tests/unittests/test_lu_op.py (+3, -0)
paddle/fluid/operators/lu_op.cc
@@ -149,7 +149,67 @@ class LUKernel : public framework::OpKernel<T> {
  }
};

template <typename T>
class LUOpGradMaker : public framework::SingleGradOpMaker<T> {
 public:
  using framework::SingleGradOpMaker<T>::SingleGradOpMaker;

 protected:
  void Apply(GradOpPtr<T> retv) const override {
    retv->SetType("lu_grad");
    retv->SetInput("X", this->Input("X"));
    retv->SetInput("Out", this->Output("Out"));
    retv->SetInput("Pivots", this->Output("Pivots"));
    retv->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
    retv->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
    retv->SetAttrMap(this->Attrs());
  }
};

class LUGradOpVarTypeInference : public framework::VarTypeInference {
 public:
  void operator()(framework::InferVarTypeContext *ctx) const override {
    auto var_type = ctx->GetInputType("X", 0);
    auto data_type = ctx->GetInputDataType("X", 0);

    ctx->SetOutputType(framework::GradVarName("X"), var_type,
                       framework::ALL_ELEMENTS);
    ctx->SetOutputDataType(framework::GradVarName("X"), data_type,
                           framework::ALL_ELEMENTS);
  }
};

class LUGradOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext *ctx) const override {
    OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "lu");
    OP_INOUT_CHECK(ctx->HasInput("Out"), "Input", "Out", "lu");
    OP_INOUT_CHECK(ctx->HasInput("Pivots"), "Input", "Pivots", "lu");
    OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Out")), "Input",
                   "Out@GRAD", "lu");

    auto x_dims = ctx->GetInputDim("X");
    auto x_grad_name = framework::GradVarName("X");
    if (ctx->HasOutput(x_grad_name)) {
      ctx->SetOutputDim(x_grad_name, x_dims);
    }
  }

 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext &ctx) const override {
    auto dtype = OperatorWithKernel::IndicateVarDataType(ctx, "X");
    return framework::OpKernelType(dtype, ctx.GetPlace());
  }
};

DECLARE_INPLACE_OP_INFERER(LUOpInplaceInferer, {"X", "Out"});
DECLARE_INPLACE_OP_INFERER(LUGradOpInplaceInferer,
                           {framework::GradVarName("Out"),
                            framework::GradVarName("X")});

}  // namespace operators
}  // namespace paddle
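The grad maker forwards the forward op's Out and Pivots into lu_grad instead of refactorizing X. A reading of LU_Unpack and Unpack_Pivot in lu_op.h (the commit does not spell this out): the packed output Y plus the pivot indices determine every factor,

$$ L = I + \operatorname{tril}_{-1}(Y), \qquad U = \operatorname{triu}(Y), \qquad X = P\,L\,U, $$

where $\operatorname{tril}_{-1}$ keeps the strictly lower triangle and P is the permutation matrix expanded from Pivots.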
@@ -157,6 +217,13 @@ namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OPERATOR(lu, ops::LUOp, ops::LUOpMaker, ops::LUOpVarTypeInference,
                  ops::LUOpGradMaker<paddle::framework::OpDesc>,
                  ops::LUOpGradMaker<paddle::imperative::OpBase>,
                  ops::LUOpInplaceInferer);
REGISTER_OPERATOR(lu_grad, ops::LUGradOp, ops::LUGradOpVarTypeInference,
                  ops::LUGradOpInplaceInferer);

REGISTER_OP_CPU_KERNEL(lu, ops::LUKernel<float>, ops::LUKernel<double>);
REGISTER_OP_CPU_KERNEL(lu_grad,
                       ops::LUGradKernel<plat::CPUDeviceContext, float>,
                       ops::LUGradKernel<plat::CPUDeviceContext, double>);
paddle/fluid/operators/lu_op.cu
@@ -152,5 +152,8 @@ namespace plat = paddle::platform;
REGISTER_OP_CUDA_KERNEL(lu, ops::LUCUDAKernel<float>,
                        ops::LUCUDAKernel<double>);
REGISTER_OP_CUDA_KERNEL(lu_grad,
                        ops::LUGradKernel<plat::CUDADeviceContext, float>,
                        ops::LUGradKernel<plat::CUDADeviceContext, double>);
#endif  // not PADDLE_WITH_HIP
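The CUDA registrations mirror the CPU ones, covering float and double. A hedged sketch of routing the same check through the GPU kernels (again assuming the paddle.linalg.lu front end, which is not part of this diff):

    import paddle

    # Dispatches to LUGradKernel<CUDADeviceContext, double> on a CUDA build.
    if paddle.is_compiled_with_cuda():
        paddle.set_device('gpu')
        x = paddle.randn([3, 5, 5], dtype='float64')  # batched square input
        x.stop_gradient = False
        lu, pivots = paddle.linalg.lu(x)
        lu.backward(paddle.ones_like(lu))  # upstream gradient of all ones
        print(x.grad.shape)                # [3, 5, 5]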
paddle/fluid/operators/lu_op.h
@@ -470,5 +470,228 @@ void Unpack_Pivot(const DeviceContext& dev_ctx, const framework::Tensor& Pivot,
  }
}

template <typename DeviceContext, typename T>
class LUGradKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    auto xin = ctx.Input<framework::Tensor>("X");
    auto out = ctx.Input<framework::Tensor>("Out");
    auto P = ctx.Input<framework::Tensor>("Pivots");
    auto dout = ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
    auto dx = ctx.Output<framework::Tensor>(framework::GradVarName("X"));
    dx->mutable_data<T>(ctx.GetPlace());

    const auto& dev_ctx = ctx.template device_context<DeviceContext>();
    math::DeviceIndependenceTensorOperations<DeviceContext, T> helper(ctx);
    auto blas = math::GetBlas<DeviceContext, T>(ctx);

    auto xdims = xin->dims();
    int xrank = xdims.size();
    int64_t m = xdims[xrank - 2];
    int64_t n = xdims[xrank - 1];
    int64_t k = std::min(m, n);

    framework::Tensor L, U, L_narrow, U_narrow, L_narrow_mH, U_narrow_mH,
        grad_narrow;
    LU_Unpack<DeviceContext, T>(dev_ctx, out, &L, &U);

    Tensor_narrow<DeviceContext, T>(ctx, &L, &L_narrow, 0, k, 0, k);
    Tensor_narrow<DeviceContext, T>(ctx, &U, &U_narrow, 0, k, 0, k);
    Tensor_narrow<DeviceContext, T>(ctx, dout, &grad_narrow, 0, k, 0, k);

    auto graddims = grad_narrow.dims();

    Tensor_Conj<DeviceContext, T>(dev_ctx, L_narrow, &L_narrow_mH);
    Tensor_Conj<DeviceContext, T>(dev_ctx, U_narrow, &U_narrow_mH);
    L_narrow_mH = helper.Transpose(L_narrow_mH);
    U_narrow_mH = helper.Transpose(U_narrow_mH);

    auto LmHdims = L_narrow_mH.dims();
    auto UmHdims = U_narrow_mH.dims();

    framework::Tensor phi_L, phi_U, phi, psi;
    phi_L.Resize(LmHdims);
    phi_L.mutable_data<T>(ctx.GetPlace());
    phi_U.Resize(UmHdims);
    phi_U.mutable_data<T>(ctx.GetPlace());
    auto mat_dim_l = math::CreateMatrixDescriptor(LmHdims, 0, false);
    auto mat_dim_u = math::CreateMatrixDescriptor(UmHdims, 0, false);
    auto mat_dim_g = math::CreateMatrixDescriptor(graddims, 0, false);
    blas.MatMul(L_narrow_mH, mat_dim_l, grad_narrow, mat_dim_g,
                static_cast<T>(1), &phi_L, static_cast<T>(0));
    blas.MatMul(grad_narrow, mat_dim_g, U_narrow_mH, mat_dim_u,
                static_cast<T>(1), &phi_U, static_cast<T>(0));

    auto phil_rank = LmHdims.size();
    auto phiu_rank = UmHdims.size();
    platform::ForRange<DeviceContext> l_for_range(dev_ctx, phi_L.numel());
    TrilTriuCompute<T> tril_computer(phi_L.data<T>(), -1, true,
                                     LmHdims[phil_rank - 2],
                                     LmHdims[phil_rank - 1], phi_L.data<T>());
    l_for_range(tril_computer);
    platform::ForRange<DeviceContext> u_for_range(dev_ctx, phi_U.numel());
    TrilTriuCompute<T> triu_computer(phi_U.data<T>(), 0, false,
                                     UmHdims[phiu_rank - 2],
                                     UmHdims[phiu_rank - 1], phi_U.data<T>());
    u_for_range(triu_computer);

    Tensor_Add<DeviceContext, T>(dev_ctx, phi_L, phi_U, &phi);
    psi.Resize(xdims);
    psi.mutable_data<T>(ctx.GetPlace());
    math::SetConstant<DeviceContext, T> setter;
    setter(dev_ctx, &psi, static_cast<T>(0));

    std::vector<int64_t> axes = {xrank - 2, xrank - 1};
    std::vector<int64_t> slice_starts(2, 0);
    std::vector<int64_t> slice_ends(2, 0);
    auto valuedims = vectorize(xdims);

    framework::Tensor Pmat;
    Unpack_Pivot<DeviceContext, T>(dev_ctx, *P, &Pmat, m, k);
    if (m <= n) {
      if (k < n) {
        framework::Tensor U_complement, U_grad_complement, phi_complement,
            phi_complement_l;
        Tensor_narrow<DeviceContext, T>(ctx, &U, &U_complement, 0, k, k, n);
        Tensor_narrow<DeviceContext, T>(ctx, dout, &U_grad_complement, 0, k,
                                        k, n);
        framework::Tensor U_complement_mH = helper.Transpose(U_complement);
        Tensor_Conj<DeviceContext, T>(dev_ctx, U_complement_mH,
                                      &U_complement_mH);
        auto mat_dim_g =
            math::CreateMatrixDescriptor(U_grad_complement.dims(), 0, false);
        auto mat_dim_u =
            math::CreateMatrixDescriptor(U_complement_mH.dims(), 0, false);
        auto phidims = UmHdims;
        phidims[UmHdims.size() - 2] = k;
        phidims[UmHdims.size() - 1] = k;
        phi_complement.Resize(phidims);
        phi_complement.mutable_data<T>(ctx.GetPlace());
        blas.MatMul(U_grad_complement, mat_dim_g, U_complement_mH, mat_dim_u,
                    static_cast<T>(1), &phi_complement, static_cast<T>(0));

        phi_complement_l.Resize(phidims);
        phi_complement_l.mutable_data<T>(ctx.GetPlace());
        const auto H = phidims[phidims.size() - 2];
        const auto W = phidims[phidims.size() - 1];
        platform::ForRange<DeviceContext> x_for_range(dev_ctx,
                                                      phi_complement.numel());
        TrilTriuCompute<T> tril_computer(phi_complement.data<T>(), -1, true,
                                         H, W, phi_complement_l.data<T>());
        x_for_range(tril_computer);

        Tensor_Sub<DeviceContext, T>(dev_ctx, phi, phi_complement_l, &phi);

        slice_starts[0] = 0;
        slice_starts[1] = k;
        slice_ends[0] = k;
        slice_ends[1] = n;
        valuedims[xrank - 2] = k;
        valuedims[xrank - 1] = n - k;
        SetValueCompute_dispatch<DeviceContext, T>(
            ctx, &psi, &U_grad_complement, &psi, axes, &slice_starts,
            &slice_ends, valuedims, xrank);
      }

      framework::Tensor psi_principal, phi_mH, psi_tmp;
      Tensor_Conj<DeviceContext, T>(dev_ctx, phi, &phi_mH);
      phi_mH = helper.Transpose(phi_mH);
      triangular_solve<DeviceContext, T>(dev_ctx, U_narrow, phi_mH,
                                         &psi_principal, true, false, false);

      Tensor_Conj<DeviceContext, T>(dev_ctx, psi_principal, &psi_principal);
      psi_principal = helper.Transpose(psi_principal);
      slice_starts[0] = 0;
      slice_starts[1] = 0;
      slice_ends[0] = k;
      slice_ends[1] = k;
      valuedims[xrank - 2] = k;
      valuedims[xrank - 1] = k;

      SetValueCompute_dispatch<DeviceContext, T>(
          ctx, &psi, &psi_principal, &psi, axes, &slice_starts, &slice_ends,
          valuedims, xrank);
      triangular_solve<DeviceContext, T>(dev_ctx, L_narrow_mH, psi, &psi_tmp,
                                         true, false, true);

      auto mat_dim_p = math::CreateMatrixDescriptor(Pmat.dims(), 0, false);
      auto mat_dim_b = math::CreateMatrixDescriptor(psi_tmp.dims(), 0, false);
      blas.MatMul(Pmat, mat_dim_p, psi_tmp, mat_dim_b, static_cast<T>(1), dx,
                  static_cast<T>(0));
    } else {
      framework::Tensor L_complement, L_grad_complement, phi_complement,
          phi_complement_u;
      Tensor_narrow<DeviceContext, T>(ctx, &L, &L_complement, k, m, 0, k);
      Tensor_narrow<DeviceContext, T>(ctx, dout, &L_grad_complement, k, m, 0,
                                      k);
      framework::Tensor L_complement_mH = helper.Transpose(L_complement);
      Tensor_Conj<DeviceContext, T>(dev_ctx, L_complement_mH,
                                    &L_complement_mH);

      auto mat_dim_g =
          math::CreateMatrixDescriptor(L_grad_complement.dims(), 0, false);
      auto mat_dim_u =
          math::CreateMatrixDescriptor(L_complement_mH.dims(), 0, false);
      auto phidims = LmHdims;
      phidims[LmHdims.size() - 2] = k;
      phidims[LmHdims.size() - 1] = k;
      phi_complement.Resize(phidims);
      phi_complement.mutable_data<T>(ctx.GetPlace());
      blas.MatMul(L_complement_mH, mat_dim_u, L_grad_complement, mat_dim_g,
                  static_cast<T>(1), &phi_complement, static_cast<T>(0));

      phi_complement_u.Resize(phidims);
      phi_complement_u.mutable_data<T>(ctx.GetPlace());
      const auto H = phidims[phidims.size() - 2];
      const auto W = phidims[phidims.size() - 1];
      platform::ForRange<DeviceContext> x_for_range(dev_ctx,
                                                    phi_complement.numel());
      TrilTriuCompute<T> triu_computer(phi_complement.data<T>(), 0, false, H,
                                       W, phi_complement_u.data<T>());
      x_for_range(triu_computer);

      Tensor_Sub<DeviceContext, T>(dev_ctx, phi, phi_complement_u, &phi);

      slice_starts[0] = k;
      slice_starts[1] = 0;
      slice_ends[0] = m;
      slice_ends[1] = k;
      valuedims[xrank - 2] = m - k;
      valuedims[xrank - 1] = k;
      SetValueCompute_dispatch<DeviceContext, T>(
          ctx, &psi, &L_grad_complement, &psi, axes, &slice_starts,
          &slice_ends, valuedims, xrank);

      framework::Tensor psi_principal, phi_mH, psi_tmp, U_narrow_mH;
      triangular_solve<DeviceContext, T>(dev_ctx, L_narrow_mH, phi,
                                         &psi_principal, true, false, true);
      slice_starts[0] = 0;
      slice_starts[1] = 0;
      slice_ends[0] = k;
      slice_ends[1] = k;
      valuedims[xrank - 2] = k;
      valuedims[xrank - 1] = k;

      SetValueCompute_dispatch<DeviceContext, T>(
          ctx, &psi, &psi_principal, &psi, axes, &slice_starts, &slice_ends,
          valuedims, xrank);

      psi_tmp.Resize(psi.dims());
      psi_tmp.mutable_data<T>(ctx.GetPlace());
      auto mat_dim_p = math::CreateMatrixDescriptor(Pmat.dims(), 0, false);
      auto mat_dim_b = math::CreateMatrixDescriptor(psi.dims(), 0, false);
      blas.MatMul(Pmat, mat_dim_p, psi, mat_dim_b, static_cast<T>(1),
                  &psi_tmp, static_cast<T>(0));
      psi_tmp = helper.Transpose(psi_tmp);

      Tensor_Conj<DeviceContext, T>(dev_ctx, U_narrow, &U_narrow_mH);
      triangular_solve<DeviceContext, T>(dev_ctx, U_narrow_mH, psi_tmp, &psi,
                                         true, false, false);
      *dx = helper.Transpose(psi);
    }
  }
};

}  // namespace operators
}  // namespace paddle
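As a reading of the kernel above (the commit itself states no formula): for square inputs, the m <= n path reduces to the standard LU backward. With $\bar{Y}$ the incoming gradient of the packed output,

$$ \bar{X} = P\,L^{-H}\,\Phi\,U^{-H}, \qquad \Phi = \operatorname{tril}_{-1}\!\left(L^{H}\bar{Y}\right) + \operatorname{triu}\!\left(\bar{Y}\,U^{H}\right). $$

phi_L and phi_U assemble the two halves of $\Phi$ via MatMul plus the TrilTriuCompute masks, the two triangular_solve calls apply $U^{-H}$ and $L^{-H}$ without forming explicit inverses, and the final MatMul with Pmat applies the permutation. The rectangular branches (the if (k < n) and else paths) first fold the gradient of the complementary U- or L-block into phi and psi before running the same solves.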
python/paddle/fluid/tests/unittests/test_lu_op.py
@@ -140,6 +140,9 @@ class TestLUOp(OpTest):
    def test_check_output(self):
        self.check_output()

    def test_check_grad(self):
        self.check_grad(['X'], ['Out'])


# m = n 2D
class TestLUOp2(TestLUOp):
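check_grad compares the operator's analytic gradient against a numeric one. The same comparison can be reproduced standalone; the following NumPy/SciPy sketch finite-difference-checks the square-case formula read from lu_op.h (an illustration, not code from this commit):

    import numpy as np
    from scipy.linalg import lu, solve_triangular

    rng = np.random.default_rng(0)
    n = 4
    A = rng.standard_normal((n, n))

    def packed(A):
        # scipy convention: A = P @ L @ U; pack LAPACK-style, dropping L's unit diagonal.
        P, L, U = lu(A)
        return np.tril(L, -1) + U, P, L, U

    Y, P, L, U = packed(A)
    Ybar = np.ones_like(Y)  # upstream gradient of sum(Y)

    # Analytic gradient, mirroring the kernel's square path:
    # Abar = P @ inv(L.T) @ Phi @ inv(U.T),
    # Phi  = tril_{-1}(L.T @ Ybar) + triu(Ybar @ U.T)
    Phi = np.tril(L.T @ Ybar, -1) + np.triu(Ybar @ U.T)
    X = solve_triangular(U, Phi.T, lower=False).T  # Phi @ inv(U.T)
    Abar = P @ solve_triangular(L.T, X, lower=False, unit_diagonal=True)

    # Central finite differences of sum(packed LU); assumes the pivot
    # order stays stable under the perturbation.
    eps, num = 1e-6, np.zeros_like(A)
    for i in range(n):
        for j in range(n):
            E = np.zeros_like(A)
            E[i, j] = eps
            num[i, j] = (packed(A + E)[0].sum() - packed(A - E)[0].sum()) / (2 * eps)

    print(np.abs(num - Abar).max())  # expected to be tiny (~1e-8) if they match

The residual is dominated by finite-difference noise; a mismatch in the tril/triu masks or in the permutation convention would instead show up as an O(1) error.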