Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
849442ef
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
849442ef
编写于
9月 27, 2020
作者:
F
ForFishes
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix the speed&memory of matmul
上级
a85592bc
变更
1
隐藏空白更改
内联
并排
Showing
1 changed file
with
280 addition
and
100 deletion
+280
-100
paddle/fluid/operators/matmul_v2_op.h
paddle/fluid/operators/matmul_v2_op.h
+280
-100
未找到文件。
paddle/fluid/operators/matmul_v2_op.h
浏览文件 @
849442ef
...
...
@@ -16,6 +16,7 @@ limitations under the License. */
#include <algorithm>
#include <functional>
#include <utility>
#include <vector>
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/op_registry.h"
...
...
@@ -350,20 +351,158 @@ class MatMulV2Kernel : public framework::OpKernel<T> {
}
};
// Reshape a rank-3 tensor from P x M x N to (P * M) x N.
// Identity op if the tensor is not of rank 3.
static
framework
::
Tensor
FoldInitDims
(
const
framework
::
Tensor
&
input
)
{
auto
output
=
input
;
auto
in_dims
=
input
.
dims
();
if
(
in_dims
.
size
()
==
3
)
{
output
.
Resize
({
in_dims
[
0
]
*
in_dims
[
1
],
in_dims
[
2
]});
}
return
output
;
}
// Reshape a rank-3 tensor from P x M x N to M x (P * N).
// (Warning: This requires transposing data and writes into new memory.)
// Identity op if the tensor is not of rank 3.
template
<
typename
DeviceContext
,
typename
T
>
static
framework
::
Tensor
FoldHeadAndLastDims
(
const
DeviceContext
&
context
,
const
framework
::
Tensor
&
input
)
{
auto
in_dims
=
input
.
dims
();
if
(
in_dims
.
size
()
!=
3
)
{
return
input
;
}
framework
::
Tensor
output
;
output
.
Resize
({
in_dims
[
1
],
in_dims
[
0
],
in_dims
[
2
]});
output
.
mutable_data
<
T
>
(
context
.
GetPlace
());
std
::
vector
<
int
>
axis
=
{
1
,
0
,
2
};
math
::
Transpose
<
DeviceContext
,
T
,
3
>
trans
;
trans
(
context
,
input
,
&
output
,
axis
);
output
.
Resize
({
in_dims
[
1
],
in_dims
[
0
]
*
in_dims
[
2
]});
return
output
;
}
/**
* Get row matrix shape from a vector shape. If the rank of x_dim > 1, the
* original x_dim is returned.
*/
static
framework
::
DDim
RowMatrixFromVector
(
const
framework
::
DDim
&
x_dim
)
{
if
(
x_dim
.
size
()
>
1
)
{
return
x_dim
;
}
return
framework
::
make_ddim
({
1
,
x_dim
[
0
]});
}
/**
* Get column matrix shape from a vector shape. If the ran of y_dim > 1, the
* original y_dim is returned.
*/
static
framework
::
DDim
ColumnMatrixFromVector
(
const
framework
::
DDim
&
y_dim
)
{
if
(
y_dim
.
size
()
>
1
)
{
return
y_dim
;
}
return
framework
::
make_ddim
({
y_dim
[
0
],
1
});
}
/**
* Reshape a tensor to 3-D or 2-D tensor by matrix descriptor.
*
* The shape would be [BatchSize, H, W] or [H, W].
* If transposed, `H,W` will be swapped.
*/
static
void
ReshapeTensorIntoMatrixSequence
(
framework
::
Tensor
*
x
,
const
math
::
MatDescriptor
&
descriptor
)
{
int64_t
h
,
w
;
h
=
descriptor
.
height_
;
w
=
descriptor
.
width_
;
if
(
descriptor
.
trans_
)
{
std
::
swap
(
w
,
h
);
}
if
(
descriptor
.
batch_size_
)
{
x
->
Resize
({
descriptor
.
batch_size_
,
h
,
w
});
}
else
{
x
->
Resize
({
h
,
w
});
}
}
static
void
ReshapeXYOutIntoMatrixSequence
(
framework
::
Tensor
*
x
,
framework
::
Tensor
*
y
,
framework
::
Tensor
*
out
,
bool
trans_x
,
bool
trans_y
)
{
auto
x_dim
=
RowMatrixFromVector
(
x
->
dims
());
auto
y_dim
=
ColumnMatrixFromVector
(
y
->
dims
());
auto
mat_dim_x
=
math
::
CreateMatrixDescriptor
(
x_dim
,
0
,
trans_x
);
auto
mat_dim_y
=
math
::
CreateMatrixDescriptor
(
y_dim
,
0
,
trans_y
);
if
(
mat_dim_x
.
batch_size_
==
0
&&
mat_dim_y
.
batch_size_
==
0
)
{
out
->
Resize
({
mat_dim_x
.
height_
,
mat_dim_y
.
width_
});
}
else
{
out
->
Resize
({
std
::
max
(
mat_dim_x
.
batch_size_
,
mat_dim_y
.
batch_size_
),
mat_dim_x
.
height_
,
mat_dim_y
.
width_
});
}
ReshapeTensorIntoMatrixSequence
(
x
,
mat_dim_x
);
ReshapeTensorIntoMatrixSequence
(
y
,
mat_dim_y
);
}
template
<
typename
DeviceContext
,
typename
T
>
class
MatMulV2GradKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
void
MatMul
(
const
framework
::
ExecutionContext
&
context
,
const
framework
::
Tensor
&
a
,
bool
trans_a
,
const
framework
::
Tensor
&
b
,
bool
trans_b
,
framework
::
Tensor
*
out
)
const
{
out
->
mutable_data
<
T
>
(
context
.
GetPlace
());
auto
blas
=
math
::
GetBlas
<
DeviceContext
,
T
>
(
context
);
auto
mat_dim_a
=
math
::
CreateMatrixDescriptor
(
a
.
dims
(),
0
,
trans_a
);
auto
mat_dim_b
=
math
::
CreateMatrixDescriptor
(
b
.
dims
(),
0
,
trans_b
);
if
(
a
.
dims
().
size
()
==
3
&&
b
.
dims
().
size
()
<=
2
)
{
// the transpose_X must be false, if is true, the transpose cost much time
if
(
!
trans_a
)
{
mat_dim_a
.
height_
*=
mat_dim_a
.
batch_size_
;
mat_dim_a
.
batch_size_
=
0
;
}
}
blas
.
MatMul
(
a
,
mat_dim_a
,
b
,
mat_dim_b
,
static_cast
<
T
>
(
1
),
out
,
static_cast
<
T
>
(
0
));
}
void
CalcInputGrad
(
const
framework
::
ExecutionContext
&
context
,
const
framework
::
Tensor
&
a
,
bool
trans_a
,
bool
is_fold_init_dims_a
,
const
framework
::
Tensor
&
b
,
bool
trans_b
,
bool
is_fold_init_dims_b
,
framework
::
Tensor
*
out
)
const
{
if
(
out
==
nullptr
)
return
;
bool
need_combine
=
(
a
.
dims
().
size
()
==
3
||
b
.
dims
().
size
()
==
3
)
&&
out
->
dims
().
size
()
==
2
;
if
(
!
need_combine
)
{
MatMul
(
context
,
a
,
trans_a
,
b
,
trans_b
,
out
);
}
else
{
auto
&
ctx
=
context
.
template
device_context
<
DeviceContext
>();
MatMul
(
context
,
is_fold_init_dims_a
?
FoldInitDims
(
a
)
:
FoldHeadAndLastDims
<
DeviceContext
,
T
>
(
ctx
,
a
),
trans_a
,
is_fold_init_dims_b
?
FoldInitDims
(
b
)
:
FoldHeadAndLastDims
<
DeviceContext
,
T
>
(
ctx
,
b
),
trans_b
,
out
);
}
}
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
auto
*
X
=
ctx
.
Input
<
Tensor
>
(
"X"
);
auto
*
Y
=
ctx
.
Input
<
Tensor
>
(
"Y"
);
auto
*
dOut
=
ctx
.
Input
<
Tensor
>
(
framework
::
GradVarName
(
"Out"
));
bool
trans_x
=
ctx
.
Attr
<
bool
>
(
"trans_x"
);
bool
trans_y
=
ctx
.
Attr
<
bool
>
(
"trans_y"
);
// auto* X = ctx.Input<Tensor>("X");
// auto* Y = ctx.Input<Tensor>("Y");
// auto* dOut = ctx.Input<Tensor>(framework::GradVarName("Out"));
bool
transpose_x
=
ctx
.
Attr
<
bool
>
(
"trans_x"
);
bool
transpose_y
=
ctx
.
Attr
<
bool
>
(
"trans_y"
);
auto
x
=
*
ctx
.
Input
<
framework
::
Tensor
>
(
"X"
);
auto
y
=
*
ctx
.
Input
<
framework
::
Tensor
>
(
"Y"
);
auto
dout
=
*
ctx
.
Input
<
framework
::
Tensor
>
(
framework
::
GradVarName
(
"Out"
));
// get dims
std
::
vector
<
std
::
int64_t
>
x_dims
=
vectorize
(
X
->
dims
());
std
::
vector
<
std
::
int64_t
>
y_dims
=
vectorize
(
Y
->
dims
());
std
::
vector
<
std
::
int64_t
>
dout_dims
=
vectorize
(
d
Out
->
dims
());
std
::
vector
<
std
::
int64_t
>
x_dims
=
vectorize
(
x
.
dims
());
std
::
vector
<
std
::
int64_t
>
y_dims
=
vectorize
(
y
.
dims
());
std
::
vector
<
std
::
int64_t
>
dout_dims
=
vectorize
(
d
out
.
dims
());
int
x_ndim
=
x_dims
.
size
();
int
y_ndim
=
y_dims
.
size
();
...
...
@@ -372,115 +511,156 @@ class MatMulV2GradKernel : public framework::OpKernel<T> {
auto
*
dx
=
ctx
.
Output
<
Tensor
>
(
framework
::
GradVarName
(
"X"
));
auto
*
dy
=
ctx
.
Output
<
Tensor
>
(
framework
::
GradVarName
(
"Y"
));
// x's or y's dim = 1
//
Case1 :
x's or y's dim = 1
if
(
x_ndim
==
1
&&
y_ndim
==
1
)
{
if
(
dx
)
dx
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
if
(
dy
)
dy
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
if
(
d
Out
->
numel
()
==
1
)
{
DotGradFunction
<
DeviceContext
,
T
>
(
X
,
Y
,
dO
ut
,
dx
,
dy
,
ctx
);
if
(
d
out
.
numel
()
==
1
)
{
DotGradFunction
<
DeviceContext
,
T
>
(
&
x
,
&
y
,
&
do
ut
,
dx
,
dy
,
ctx
);
return
;
}
}
// It is very tricky. For this broadcast, currently using the reduce sum to
// get gradient.
if
(
x_ndim
==
1
)
{
x_dims
.
insert
(
x_dims
.
begin
()
+
0
,
1
);
x_ndim
+=
1
;
if
(
trans_x
)
dout_dims
.
push_back
(
1
);
else
dout_dims
.
insert
(
dout_dims
.
begin
()
+
ndim
-
1
,
1
);
ndim
+=
1
;
}
if
(
y_ndim
==
1
)
{
y_dims
.
push_back
(
1
);
y_ndim
+=
1
;
if
(
trans_y
)
dout_dims
.
insert
(
dout_dims
.
begin
()
+
ndim
-
1
,
1
)
;
else
dout_dims
.
push_back
(
1
);
ndim
+=
1
;
bool
is_broadcast
=
true
;
if
(
x_ndim
<=
2
||
y_ndim
<=
2
)
{
is_broadcast
=
false
;
}
else
if
(
x_ndim
!=
y_ndim
)
{
is_broadcast
=
true
;
}
else
{
is_broadcast
=
!
std
::
equal
(
x_dims
.
cbegin
(),
x_dims
.
cbegin
()
+
x_ndim
-
2
,
y_dims
.
cbegin
())
;
}
// the normal case
Tensor
dx_help
,
dy_help
;
if
(
trans_x
)
{
if
(
trans_y
)
{
// X'Y': dA = Y'G', dB = G'X'
if
(
dx
)
MatMulFunction
<
DeviceContext
,
T
>
(
Y
,
dOut
,
y_dims
,
dout_dims
,
&
dx_help
,
true
,
true
,
ctx
);
if
(
dy
)
MatMulFunction
<
DeviceContext
,
T
>
(
dOut
,
X
,
dout_dims
,
x_dims
,
&
dy_help
,
true
,
true
,
ctx
);
VLOG
(
0
)
<<
"is_broadcast: "
<<
is_broadcast
;
// Case2: no broadcast or no batch size, it aims to speed and it is same as
// matmul in old version.
if
(
!
is_broadcast
)
{
ReshapeXYOutIntoMatrixSequence
(
&
x
,
&
y
,
&
dout
,
transpose_x
,
transpose_y
);
framework
::
DDim
dx_dims
;
if
(
dx
)
{
dx_dims
=
dx
->
dims
();
if
(
dx_dims
!=
x
.
dims
())
{
dx
->
Resize
(
x
.
dims
());
}
}
framework
::
DDim
dy_dims
;
if
(
dy
)
{
dy_dims
=
dy
->
dims
();
if
(
dy_dims
!=
y
.
dims
())
{
dy
->
Resize
(
y
.
dims
());
}
}
if
(
transpose_x
&&
transpose_y
)
{
CalcInputGrad
(
ctx
,
y
,
true
,
true
,
dout
,
true
,
false
,
dx
);
CalcInputGrad
(
ctx
,
dout
,
true
,
true
,
x
,
true
,
false
,
dy
);
}
else
if
(
transpose_x
)
{
CalcInputGrad
(
ctx
,
y
,
false
,
false
,
dout
,
true
,
false
,
dx
);
CalcInputGrad
(
ctx
,
x
,
false
,
false
,
dout
,
false
,
true
,
dy
);
}
else
if
(
transpose_y
)
{
CalcInputGrad
(
ctx
,
dout
,
false
,
false
,
y
,
false
,
true
,
dx
);
CalcInputGrad
(
ctx
,
dout
,
true
,
true
,
x
,
false
,
true
,
dy
);
}
else
{
// X'Y: dX = YG', dY = XG
if
(
dx
)
MatMulFunction
<
DeviceContext
,
T
>
(
Y
,
dOut
,
y_dims
,
dout_dims
,
&
dx_help
,
false
,
true
,
ctx
);
if
(
dy
)
MatMulFunction
<
DeviceContext
,
T
>
(
X
,
dOut
,
x_dims
,
dout_dims
,
&
dy_help
,
false
,
false
,
ctx
);
CalcInputGrad
(
ctx
,
dout
,
false
,
false
,
y
,
true
,
false
,
dx
);
CalcInputGrad
(
ctx
,
x
,
true
,
true
,
dout
,
false
,
true
,
dy
);
}
if
(
dx
)
{
if
(
dx_dims
!=
x
.
dims
())
{
dx
->
Resize
(
dx_dims
);
}
}
if
(
dy
)
{
if
(
dy_dims
!=
y
.
dims
())
{
dy
->
Resize
(
dy_dims
);
}
}
}
else
{
if
(
trans_y
)
{
// XY': dX = GY, dY = G'X
if
(
dx
)
MatMulFunction
<
DeviceContext
,
T
>
(
dOut
,
Y
,
dout_dims
,
y_dims
,
&
dx_help
,
false
,
false
,
ctx
);
if
(
dy
)
MatMulFunction
<
DeviceContext
,
T
>
(
dOut
,
X
,
dout_dims
,
x_dims
,
&
dy_help
,
true
,
false
,
ctx
);
// Case3: broadcast. It need cost much time to reduce sum for the
// broadcast and wastes the memory.
// So we should avoid the case in reality.
VLOG
(
3
)
<<
"It need cost much time to reduce sum for the broadcast and "
"wastes the memory. So we should avoid the case in reality"
;
Tensor
dx_help
,
dy_help
;
if
(
transpose_x
)
{
if
(
transpose_y
)
{
// X'Y': dA = Y'G', dB = G'X'
if
(
dx
)
MatMulFunction
<
DeviceContext
,
T
>
(
&
y
,
&
dout
,
y_dims
,
dout_dims
,
&
dx_help
,
true
,
true
,
ctx
);
if
(
dy
)
MatMulFunction
<
DeviceContext
,
T
>
(
&
dout
,
&
x
,
dout_dims
,
x_dims
,
&
dy_help
,
true
,
true
,
ctx
);
}
else
{
// X'Y: dX = YG', dY = XG
if
(
dx
)
MatMulFunction
<
DeviceContext
,
T
>
(
&
y
,
&
dout
,
y_dims
,
dout_dims
,
&
dx_help
,
false
,
true
,
ctx
);
if
(
dy
)
MatMulFunction
<
DeviceContext
,
T
>
(
&
x
,
&
dout
,
x_dims
,
dout_dims
,
&
dy_help
,
false
,
false
,
ctx
);
}
}
else
{
// XY: dX = GY', dY = X'G
if
(
dx
)
MatMulFunction
<
DeviceContext
,
T
>
(
dOut
,
Y
,
dout_dims
,
y_dims
,
&
dx_help
,
false
,
true
,
ctx
);
if
(
dy
)
MatMulFunction
<
DeviceContext
,
T
>
(
X
,
dOut
,
x_dims
,
dout_dims
,
&
dy_help
,
true
,
false
,
ctx
);
if
(
transpose_y
)
{
// XY': dX = GY, dY = G'X
if
(
dx
)
MatMulFunction
<
DeviceContext
,
T
>
(
&
dout
,
&
y
,
dout_dims
,
y_dims
,
&
dx_help
,
false
,
false
,
ctx
);
if
(
dy
)
MatMulFunction
<
DeviceContext
,
T
>
(
&
dout
,
&
x
,
dout_dims
,
x_dims
,
&
dy_help
,
true
,
false
,
ctx
);
}
else
{
// XY: dX = GY', dY = X'G
if
(
dx
)
MatMulFunction
<
DeviceContext
,
T
>
(
&
dout
,
&
y
,
dout_dims
,
y_dims
,
&
dx_help
,
false
,
true
,
ctx
);
if
(
dy
)
MatMulFunction
<
DeviceContext
,
T
>
(
&
x
,
&
dout
,
x_dims
,
dout_dims
,
&
dy_help
,
true
,
false
,
ctx
);
}
}
}
// get help dims
const
std
::
vector
<
std
::
int64_t
>
dx_help_dims
=
vectorize
(
dx_help
.
dims
());
const
std
::
vector
<
std
::
int64_t
>
dy_help_dims
=
vectorize
(
dy_help
.
dims
());
std
::
vector
<
std
::
int64_t
>
dx_broadcast_dims
(
ndim
);
std
::
vector
<
std
::
int64_t
>
dy_broadcast_dims
(
ndim
);
std
::
fill
(
dx_broadcast_dims
.
data
(),
dx_broadcast_dims
.
data
()
+
ndim
-
x_ndim
,
1
);
std
::
fill
(
dy_broadcast_dims
.
data
(),
dy_broadcast_dims
.
data
()
+
ndim
-
y_ndim
,
1
);
std
::
copy
(
x_dims
.
data
(),
x_dims
.
data
()
+
x_ndim
,
dx_broadcast_dims
.
data
()
+
ndim
-
x_ndim
);
std
::
copy
(
y_dims
.
data
(),
y_dims
.
data
()
+
y_ndim
,
dy_broadcast_dims
.
data
()
+
ndim
-
y_ndim
);
std
::
vector
<
int
>
dx_reduce_dims
;
std
::
vector
<
int
>
dy_reduce_dims
;
for
(
int
idx
=
0
;
idx
<=
ndim
-
3
;
idx
++
)
{
if
(
dx_help_dims
[
idx
]
!=
1
&&
dx_broadcast_dims
[
idx
]
==
1
)
{
dx_reduce_dims
.
push_back
(
idx
);
// get help dims
const
std
::
vector
<
std
::
int64_t
>
dx_help_dims
=
vectorize
(
dx_help
.
dims
());
const
std
::
vector
<
std
::
int64_t
>
dy_help_dims
=
vectorize
(
dy_help
.
dims
());
std
::
vector
<
std
::
int64_t
>
dx_broadcast_dims
(
ndim
);
std
::
vector
<
std
::
int64_t
>
dy_broadcast_dims
(
ndim
);
std
::
fill
(
dx_broadcast_dims
.
data
(),
dx_broadcast_dims
.
data
()
+
ndim
-
x_ndim
,
1
);
std
::
fill
(
dy_broadcast_dims
.
data
(),
dy_broadcast_dims
.
data
()
+
ndim
-
y_ndim
,
1
);
std
::
copy
(
x_dims
.
data
(),
x_dims
.
data
()
+
x_ndim
,
dx_broadcast_dims
.
data
()
+
ndim
-
x_ndim
);
std
::
copy
(
y_dims
.
data
(),
y_dims
.
data
()
+
y_ndim
,
dy_broadcast_dims
.
data
()
+
ndim
-
y_ndim
);
std
::
vector
<
int
>
dx_reduce_dims
;
std
::
vector
<
int
>
dy_reduce_dims
;
for
(
int
idx
=
0
;
idx
<=
ndim
-
3
;
idx
++
)
{
if
(
dx_help_dims
[
idx
]
!=
1
&&
dx_broadcast_dims
[
idx
]
==
1
)
{
dx_reduce_dims
.
push_back
(
idx
);
}
if
(
dy_help_dims
[
idx
]
!=
1
&&
dy_broadcast_dims
[
idx
]
==
1
)
{
dy_reduce_dims
.
push_back
(
idx
);
}
}
// reduce sum to get grad by ReduceSum
if
(
dx
)
{
dx
->
Resize
(
dx_help
.
dims
());
ReduceSumForMatmulGrad
<
DeviceContext
,
T
>
(
&
dx_help
,
dx
,
dx_reduce_dims
,
ctx
);
dx
->
Resize
(
x
.
dims
());
}
if
(
dy_help_dims
[
idx
]
!=
1
&&
dy_broadcast_dims
[
idx
]
==
1
)
{
dy_reduce_dims
.
push_back
(
idx
);
if
(
dy
)
{
dy
->
Resize
(
dy_help
.
dims
());
ReduceSumForMatmulGrad
<
DeviceContext
,
T
>
(
&
dy_help
,
dy
,
dy_reduce_dims
,
ctx
);
dy
->
Resize
(
y
.
dims
());
}
}
// reduce sum to get grad by ReduceSum
if
(
dx
)
{
dx
->
Resize
(
dx_help
.
dims
());
ReduceSumForMatmulGrad
<
DeviceContext
,
T
>
(
&
dx_help
,
dx
,
dx_reduce_dims
,
ctx
);
dx
->
Resize
(
X
->
dims
());
}
if
(
dy
)
{
dy
->
Resize
(
dy_help
.
dims
());
ReduceSumForMatmulGrad
<
DeviceContext
,
T
>
(
&
dy_help
,
dy
,
dy_reduce_dims
,
ctx
);
dy
->
Resize
(
Y
->
dims
());
}
}
};
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录