Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
0ed26e12
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
0ed26e12
编写于
3月 10, 2022
作者:
R
root
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
support weight transpose
上级
60b86b2f
变更
9
显示空白变更内容
内联
并排
Showing
9 changed file
with
212 addition
and
68 deletion
+212
-68
paddle/fluid/framework/distributed_strategy.proto
paddle/fluid/framework/distributed_strategy.proto
+1
-0
paddle/fluid/framework/ir/fuse_gemm_epilogue_pass.cc
paddle/fluid/framework/ir/fuse_gemm_epilogue_pass.cc
+34
-9
paddle/fluid/operators/fused/fused_gemm_epilogue_op.cc
paddle/fluid/operators/fused/fused_gemm_epilogue_op.cc
+7
-2
paddle/fluid/operators/fused/fused_gemm_epilogue_op.cu
paddle/fluid/operators/fused/fused_gemm_epilogue_op.cu
+138
-48
python/paddle/fluid/tests/unittests/test_fuse_gemm_epilogue_pass.py
...dle/fluid/tests/unittests/test_fuse_gemm_epilogue_pass.py
+17
-3
python/paddle/fluid/tests/unittests/test_fused_gemm_epilogue_grad_op.py
...fluid/tests/unittests/test_fused_gemm_epilogue_grad_op.py
+1
-0
python/paddle/fluid/tests/unittests/test_fused_gemm_epilogue_op.py
...ddle/fluid/tests/unittests/test_fused_gemm_epilogue_op.py
+1
-0
python/paddle/nn/functional/common.py
python/paddle/nn/functional/common.py
+3
-3
python/paddle/nn/layer/common.py
python/paddle/nn/layer/common.py
+10
-3
未找到文件。
paddle/fluid/framework/distributed_strategy.proto
浏览文件 @
0ed26e12
...
@@ -118,6 +118,7 @@ message BuildStrategy {
...
@@ -118,6 +118,7 @@ message BuildStrategy {
optional
bool
fix_op_run_order
=
13
[
default
=
false
];
optional
bool
fix_op_run_order
=
13
[
default
=
false
];
optional
bool
allow_cuda_graph_capture
=
14
[
default
=
false
];
optional
bool
allow_cuda_graph_capture
=
14
[
default
=
false
];
optional
int32
reduce_strategy
=
15
[
default
=
0
];
optional
int32
reduce_strategy
=
15
[
default
=
0
];
optional
bool
fuse_gemm_epilogue
=
16
[
default
=
false
];
}
}
message
ExecutionStrategy
{
message
ExecutionStrategy
{
...
...
paddle/fluid/framework/ir/fuse_gemm_epilogue_pass.cc
浏览文件 @
0ed26e12
...
@@ -18,18 +18,28 @@
...
@@ -18,18 +18,28 @@
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/enforce.h"
PADDLE_DEFINE_EXPORTED_bool
(
enable_gemm_fwd_fusion
,
true
,
""
);
namespace
paddle
{
namespace
paddle
{
namespace
framework
{
namespace
framework
{
namespace
ir
{
namespace
ir
{
static
void
GetTransposeAttrsFromOp
(
const
OpDesc
&
op
,
bool
*
trans_x
,
bool
*
trans_y
)
{
*
trans_x
=
BOOST_GET_CONST
(
bool
,
op
.
GetAttr
(
"trans_x"
));
*
trans_y
=
BOOST_GET_CONST
(
bool
,
op
.
GetAttr
(
"trans_y"
));
}
void
FuseGemmEpiloguePass
::
ApplyImpl
(
ir
::
Graph
*
graph
)
const
{
void
FuseGemmEpiloguePass
::
ApplyImpl
(
ir
::
Graph
*
graph
)
const
{
EpiloguePassActivationCache
cache
;
EpiloguePassActivationCache
cache
;
if
(
FLAGS_enable_gemm_fwd_fusion
)
{
graph
=
FuseLinearActFwd
(
graph
,
{
"relu"
,
"gelu"
},
false
,
false
,
&
cache
);
graph
=
FuseLinearActFwd
(
graph
,
{
"relu"
,
"gelu"
},
false
,
false
,
&
cache
);
graph
=
FuseLinearActFwd
(
graph
,
{
"relu"
},
true
,
true
,
&
cache
);
graph
=
FuseLinearActFwd
(
graph
,
{
"relu"
},
true
,
true
,
&
cache
);
graph
=
FuseLinearActFwd
(
graph
,
{
"gelu"
},
true
,
false
,
&
cache
);
graph
=
FuseLinearActFwd
(
graph
,
{
"gelu"
},
true
,
false
,
&
cache
);
graph
=
FuseLinearFwd
(
graph
,
false
);
graph
=
FuseLinearFwd
(
graph
,
false
);
graph
=
FuseLinearFwd
(
graph
,
true
);
graph
=
FuseLinearFwd
(
graph
,
true
);
}
graph
=
FuseLinearActBwd
(
graph
,
{
"relu_grad"
},
true
,
&
cache
);
graph
=
FuseLinearActBwd
(
graph
,
{
"relu_grad"
},
true
,
&
cache
);
graph
=
FuseLinearActBwd
(
graph
,
{
"gelu_grad"
},
false
,
&
cache
);
graph
=
FuseLinearActBwd
(
graph
,
{
"gelu_grad"
},
false
,
&
cache
);
graph
=
FuseLinearBwd
(
graph
,
false
);
graph
=
FuseLinearBwd
(
graph
,
false
);
...
@@ -75,6 +85,9 @@ ir::Graph *FuseGemmEpiloguePass::FuseLinearFwd(ir::Graph *graph,
...
@@ -75,6 +85,9 @@ ir::Graph *FuseGemmEpiloguePass::FuseLinearFwd(ir::Graph *graph,
if
(
!
IsGemmFromLinear_
(
matmul_x_shape
,
matmul_w_shape
,
matmul_op_desc
))
if
(
!
IsGemmFromLinear_
(
matmul_x_shape
,
matmul_w_shape
,
matmul_op_desc
))
return
;
return
;
bool
trans_x
,
trans_y
;
GetTransposeAttrsFromOp
(
*
matmul_op_desc
,
&
trans_x
,
&
trans_y
);
OpDesc
fused_gemm_epilogue_op_desc
(
matmul_op
->
Op
()
->
Block
());
OpDesc
fused_gemm_epilogue_op_desc
(
matmul_op
->
Op
()
->
Block
());
std
::
string
activation
=
"none"
;
std
::
string
activation
=
"none"
;
fused_gemm_epilogue_op_desc
.
SetType
(
"fused_gemm_epilogue"
);
fused_gemm_epilogue_op_desc
.
SetType
(
"fused_gemm_epilogue"
);
...
@@ -85,6 +98,8 @@ ir::Graph *FuseGemmEpiloguePass::FuseLinearFwd(ir::Graph *graph,
...
@@ -85,6 +98,8 @@ ir::Graph *FuseGemmEpiloguePass::FuseLinearFwd(ir::Graph *graph,
fused_gemm_epilogue_op_desc
.
SetAttr
(
"activation"
,
activation
);
fused_gemm_epilogue_op_desc
.
SetAttr
(
"activation"
,
activation
);
fused_gemm_epilogue_op_desc
.
SetAttr
(
"op_role"
,
fused_gemm_epilogue_op_desc
.
SetAttr
(
"op_role"
,
matmul_op_desc
->
GetAttr
(
"op_role"
));
matmul_op_desc
->
GetAttr
(
"op_role"
));
fused_gemm_epilogue_op_desc
.
SetAttr
(
"trans_x"
,
trans_x
);
fused_gemm_epilogue_op_desc
.
SetAttr
(
"trans_y"
,
trans_y
);
auto
gemm_epilogue_node
=
g
->
CreateOpNode
(
&
fused_gemm_epilogue_op_desc
);
auto
gemm_epilogue_node
=
g
->
CreateOpNode
(
&
fused_gemm_epilogue_op_desc
);
IR_NODE_LINK_TO
(
subgraph
.
at
(
x
),
gemm_epilogue_node
);
IR_NODE_LINK_TO
(
subgraph
.
at
(
x
),
gemm_epilogue_node
);
...
@@ -154,6 +169,9 @@ ir::Graph *FuseGemmEpiloguePass::FuseLinearActFwd(
...
@@ -154,6 +169,9 @@ ir::Graph *FuseGemmEpiloguePass::FuseLinearActFwd(
auto
activation
=
act_op
->
Op
()
->
Type
();
auto
activation
=
act_op
->
Op
()
->
Type
();
bool
trans_x
,
trans_y
;
GetTransposeAttrsFromOp
(
*
matmul_op_desc
,
&
trans_x
,
&
trans_y
);
OpDesc
fused_gemm_epilogue_op_desc
(
matmul_op
->
Op
()
->
Block
());
OpDesc
fused_gemm_epilogue_op_desc
(
matmul_op
->
Op
()
->
Block
());
fused_gemm_epilogue_op_desc
.
SetType
(
"fused_gemm_epilogue"
);
fused_gemm_epilogue_op_desc
.
SetType
(
"fused_gemm_epilogue"
);
fused_gemm_epilogue_op_desc
.
SetInput
(
"X"
,
{
subgraph
.
at
(
x
)
->
Name
()});
fused_gemm_epilogue_op_desc
.
SetInput
(
"X"
,
{
subgraph
.
at
(
x
)
->
Name
()});
...
@@ -163,6 +181,8 @@ ir::Graph *FuseGemmEpiloguePass::FuseLinearActFwd(
...
@@ -163,6 +181,8 @@ ir::Graph *FuseGemmEpiloguePass::FuseLinearActFwd(
fused_gemm_epilogue_op_desc
.
SetAttr
(
"activation"
,
activation
);
fused_gemm_epilogue_op_desc
.
SetAttr
(
"activation"
,
activation
);
fused_gemm_epilogue_op_desc
.
SetAttr
(
"op_role"
,
fused_gemm_epilogue_op_desc
.
SetAttr
(
"op_role"
,
matmul_op_desc
->
GetAttr
(
"op_role"
));
matmul_op_desc
->
GetAttr
(
"op_role"
));
fused_gemm_epilogue_op_desc
.
SetAttr
(
"trans_x"
,
trans_x
);
fused_gemm_epilogue_op_desc
.
SetAttr
(
"trans_y"
,
trans_y
);
auto
gemm_epilogue_node
=
g
->
CreateOpNode
(
&
fused_gemm_epilogue_op_desc
);
auto
gemm_epilogue_node
=
g
->
CreateOpNode
(
&
fused_gemm_epilogue_op_desc
);
...
@@ -274,6 +294,9 @@ ir::Graph *FuseGemmEpiloguePass::FuseLinearBwd(ir::Graph *graph,
...
@@ -274,6 +294,9 @@ ir::Graph *FuseGemmEpiloguePass::FuseLinearBwd(ir::Graph *graph,
matmul_grad_op_desc
))
matmul_grad_op_desc
))
return
;
return
;
bool
trans_x
,
trans_y
;
GetTransposeAttrsFromOp
(
*
matmul_grad_op_desc
,
&
trans_x
,
&
trans_y
);
OpDesc
fused_gemm_epilogue_grad_op_desc
(
ele_add_grad_op
->
Op
()
->
Block
());
OpDesc
fused_gemm_epilogue_grad_op_desc
(
ele_add_grad_op
->
Op
()
->
Block
());
std
::
string
activation_grad
=
"none"
;
std
::
string
activation_grad
=
"none"
;
fused_gemm_epilogue_grad_op_desc
.
SetType
(
"fused_gemm_epilogue_grad"
);
fused_gemm_epilogue_grad_op_desc
.
SetType
(
"fused_gemm_epilogue_grad"
);
...
@@ -292,6 +315,8 @@ ir::Graph *FuseGemmEpiloguePass::FuseLinearBwd(ir::Graph *graph,
...
@@ -292,6 +315,8 @@ ir::Graph *FuseGemmEpiloguePass::FuseLinearBwd(ir::Graph *graph,
activation_grad
);
activation_grad
);
fused_gemm_epilogue_grad_op_desc
.
SetAttr
(
fused_gemm_epilogue_grad_op_desc
.
SetAttr
(
"op_role"
,
matmul_grad_op_desc
->
GetAttr
(
"op_role"
));
"op_role"
,
matmul_grad_op_desc
->
GetAttr
(
"op_role"
));
fused_gemm_epilogue_grad_op_desc
.
SetAttr
(
"trans_x"
,
trans_x
);
fused_gemm_epilogue_grad_op_desc
.
SetAttr
(
"trans_y"
,
trans_y
);
auto
gemm_epilogue_grad_node
=
auto
gemm_epilogue_grad_node
=
g
->
CreateOpNode
(
&
fused_gemm_epilogue_grad_op_desc
);
g
->
CreateOpNode
(
&
fused_gemm_epilogue_grad_op_desc
);
...
@@ -394,6 +419,8 @@ ir::Graph *FuseGemmEpiloguePass::FuseLinearActBwd(
...
@@ -394,6 +419,8 @@ ir::Graph *FuseGemmEpiloguePass::FuseLinearActBwd(
auto
activation_grad
=
act_grad_op
->
Op
()
->
Type
();
auto
activation_grad
=
act_grad_op
->
Op
()
->
Type
();
bool
trans_x
,
trans_y
;
GetTransposeAttrsFromOp
(
*
matmul_grad_op_desc
,
&
trans_x
,
&
trans_y
);
OpDesc
fused_gemm_epilogue_grad_op_desc
(
ele_add_grad_op
->
Op
()
->
Block
());
OpDesc
fused_gemm_epilogue_grad_op_desc
(
ele_add_grad_op
->
Op
()
->
Block
());
fused_gemm_epilogue_grad_op_desc
.
SetType
(
"fused_gemm_epilogue_grad"
);
fused_gemm_epilogue_grad_op_desc
.
SetType
(
"fused_gemm_epilogue_grad"
);
fused_gemm_epilogue_grad_op_desc
.
SetInput
(
"DOut"
,
fused_gemm_epilogue_grad_op_desc
.
SetInput
(
"DOut"
,
...
@@ -410,6 +437,8 @@ ir::Graph *FuseGemmEpiloguePass::FuseLinearActBwd(
...
@@ -410,6 +437,8 @@ ir::Graph *FuseGemmEpiloguePass::FuseLinearActBwd(
activation_grad
);
activation_grad
);
fused_gemm_epilogue_grad_op_desc
.
SetAttr
(
fused_gemm_epilogue_grad_op_desc
.
SetAttr
(
"op_role"
,
matmul_grad_op_desc
->
GetAttr
(
"op_role"
));
"op_role"
,
matmul_grad_op_desc
->
GetAttr
(
"op_role"
));
fused_gemm_epilogue_grad_op_desc
.
SetAttr
(
"trans_x"
,
trans_x
);
fused_gemm_epilogue_grad_op_desc
.
SetAttr
(
"trans_y"
,
trans_y
);
auto
gemm_epilogue_grad_node
=
auto
gemm_epilogue_grad_node
=
g
->
CreateOpNode
(
&
fused_gemm_epilogue_grad_op_desc
);
g
->
CreateOpNode
(
&
fused_gemm_epilogue_grad_op_desc
);
...
@@ -456,10 +485,6 @@ bool FuseGemmEpiloguePass::IsGemmFromLinear_(
...
@@ -456,10 +485,6 @@ bool FuseGemmEpiloguePass::IsGemmFromLinear_(
if
(
tmp_vec
.
size
()
>
0
)
return
false
;
if
(
tmp_vec
.
size
()
>
0
)
return
false
;
}
}
}
}
if
(
BOOST_GET_CONST
(
bool
,
matmul_v2_op
->
GetAttr
(
"trans_x"
))
||
BOOST_GET_CONST
(
bool
,
matmul_v2_op
->
GetAttr
(
"trans_y"
)))
return
false
;
return
true
;
return
true
;
}
}
...
...
paddle/fluid/operators/fused/fused_gemm_epilogue_op.cc
浏览文件 @
0ed26e12
...
@@ -208,6 +208,9 @@ class FusedGemmEpilogueGradOp : public framework::OperatorWithKernel {
...
@@ -208,6 +208,9 @@ class FusedGemmEpilogueGradOp : public framework::OperatorWithKernel {
auto
x_dims
=
ctx
->
GetInputDim
(
"X"
);
auto
x_dims
=
ctx
->
GetInputDim
(
"X"
);
auto
y_dims
=
ctx
->
GetInputDim
(
"Y"
);
auto
y_dims
=
ctx
->
GetInputDim
(
"Y"
);
auto
trans_x
=
ctx
->
Attrs
().
Get
<
bool
>
(
"trans_x"
);
auto
trans_y
=
ctx
->
Attrs
().
Get
<
bool
>
(
"trans_y"
);
PADDLE_ENFORCE_GE
(
PADDLE_ENFORCE_GE
(
dout_dims
.
size
(),
2
,
dout_dims
.
size
(),
2
,
platform
::
errors
::
InvalidArgument
(
platform
::
errors
::
InvalidArgument
(
...
@@ -242,14 +245,14 @@ class FusedGemmEpilogueGradOp : public framework::OperatorWithKernel {
...
@@ -242,14 +245,14 @@ class FusedGemmEpilogueGradOp : public framework::OperatorWithKernel {
auto
x_mat_dims
=
phi
::
flatten_to_2d
(
x_dims
,
x_dims
.
size
()
-
1
);
auto
x_mat_dims
=
phi
::
flatten_to_2d
(
x_dims
,
x_dims
.
size
()
-
1
);
PADDLE_ENFORCE_EQ
(
PADDLE_ENFORCE_EQ
(
dout_mat_dims
[
1
],
y_dims
[
1
],
dout_mat_dims
[
1
],
trans_y
?
y_dims
[
0
]
:
y_dims
[
1
],
platform
::
errors
::
InvalidArgument
(
platform
::
errors
::
InvalidArgument
(
"The last dimension of DOut should be equal with Y's last"
"The last dimension of DOut should be equal with Y's last"
"dimension. But received DOut[-1] = [%d], Y[1] = [%d]."
,
"dimension. But received DOut[-1] = [%d], Y[1] = [%d]."
,
dout_mat_dims
[
1
],
y_dims
[
1
]));
dout_mat_dims
[
1
],
y_dims
[
1
]));
PADDLE_ENFORCE_EQ
(
PADDLE_ENFORCE_EQ
(
dout_mat_dims
[
0
],
x_mat_dims
[
0
],
dout_mat_dims
[
0
],
trans_x
?
x_mat_dims
[
1
]
:
x_mat_dims
[
0
],
platform
::
errors
::
InvalidArgument
(
platform
::
errors
::
InvalidArgument
(
"The first dimension of DOut should be equal with X's first"
"The first dimension of DOut should be equal with X's first"
"dimension. But received DOut[0] = [%d], Y[0] = [%d]."
,
"dimension. But received DOut[0] = [%d], Y[0] = [%d]."
,
...
@@ -323,6 +326,8 @@ class FusedGemmEpilogueGradOpMaker : public framework::OpProtoAndCheckerMaker {
...
@@ -323,6 +326,8 @@ class FusedGemmEpilogueGradOpMaker : public framework::OpProtoAndCheckerMaker {
AddOutput
(
"DBias"
,
AddOutput
(
"DBias"
,
"The output grad tensor to bias of Out = (Act(X) * Y) + bias."
)
"The output grad tensor to bias of Out = (Act(X) * Y) + bias."
)
.
AsDispensable
();
.
AsDispensable
();
AddAttr
<
bool
>
(
"trans_x"
,
""
).
SetDefault
(
false
);
AddAttr
<
bool
>
(
"trans_y"
,
""
).
SetDefault
(
false
);
AddAttr
<
std
::
string
>
(
AddAttr
<
std
::
string
>
(
"activation_grad"
,
"activation_grad"
,
...
...
paddle/fluid/operators/fused/fused_gemm_epilogue_op.cu
浏览文件 @
0ed26e12
...
@@ -40,6 +40,9 @@ class FusedGemmEpilogueKernel : public framework::OpKernel<T> {
...
@@ -40,6 +40,9 @@ class FusedGemmEpilogueKernel : public framework::OpKernel<T> {
bool
trans_y
=
ctx
.
Attr
<
bool
>
(
"trans_y"
);
bool
trans_y
=
ctx
.
Attr
<
bool
>
(
"trans_y"
);
std
::
string
activation
=
ctx
.
Attr
<
std
::
string
>
(
"activation"
);
std
::
string
activation
=
ctx
.
Attr
<
std
::
string
>
(
"activation"
);
VLOG
(
10
)
<<
"trans_x = "
<<
trans_x
<<
" , trans_y = "
<<
trans_y
<<
" , activation = "
<<
activation
;
// activation = "none";
bool
enable_auxiliary
=
reserve_space
==
nullptr
?
false
:
true
;
bool
enable_auxiliary
=
reserve_space
==
nullptr
?
false
:
true
;
out
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
out
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
...
@@ -56,7 +59,7 @@ class FusedGemmEpilogueKernel : public framework::OpKernel<T> {
...
@@ -56,7 +59,7 @@ class FusedGemmEpilogueKernel : public framework::OpKernel<T> {
cublasComputeType_t
compute_type
=
CUBLAS_COMPUTE_32F
;
cublasComputeType_t
compute_type
=
CUBLAS_COMPUTE_32F
;
if
(
std
::
is_same
<
T
,
paddle
::
platform
::
float16
>::
value
)
{
if
(
std
::
is_same
<
T
,
paddle
::
platform
::
float16
>::
value
)
{
mat_type
=
CUDA_R_16F
;
mat_type
=
CUDA_R_16F
;
scale_type
=
CUDA_R_
16
F
;
scale_type
=
CUDA_R_
32
F
;
}
}
if
(
std
::
is_same
<
T
,
double
>::
value
)
{
if
(
std
::
is_same
<
T
,
double
>::
value
)
{
mat_type
=
CUDA_R_64F
;
mat_type
=
CUDA_R_64F
;
...
@@ -106,10 +109,12 @@ class FusedGemmEpilogueKernel : public framework::OpKernel<T> {
...
@@ -106,10 +109,12 @@ class FusedGemmEpilogueKernel : public framework::OpKernel<T> {
platform
::
dynload
::
cublasLtMatmulDescSetAttribute
(
platform
::
dynload
::
cublasLtMatmulDescSetAttribute
(
operation_desc
,
CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_POINTER
,
operation_desc
,
CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_POINTER
,
&
aux_data
,
sizeof
(
aux_data
)));
&
aux_data
,
sizeof
(
aux_data
)));
// int64_t aux_ld = trans_y ? K : N;
int64_t
aux_ld
=
N
;
PADDLE_ENFORCE_GPU_SUCCESS
(
PADDLE_ENFORCE_GPU_SUCCESS
(
platform
::
dynload
::
cublasLtMatmulDescSetAttribute
(
platform
::
dynload
::
cublasLtMatmulDescSetAttribute
(
operation_desc
,
CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_LD
,
&
N
,
operation_desc
,
CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_LD
,
&
aux_ld
,
sizeof
(
N
)));
sizeof
(
aux_ld
)));
}
}
cublasLtMatrixLayout_t
x_desc
=
NULL
,
y_desc
=
NULL
,
out_desc
=
NULL
;
cublasLtMatrixLayout_t
x_desc
=
NULL
,
y_desc
=
NULL
,
out_desc
=
NULL
;
...
@@ -129,7 +134,7 @@ class FusedGemmEpilogueKernel : public framework::OpKernel<T> {
...
@@ -129,7 +134,7 @@ class FusedGemmEpilogueKernel : public framework::OpKernel<T> {
&
out_desc
,
mat_type
,
N
,
M
,
N
));
&
out_desc
,
mat_type
,
N
,
M
,
N
));
cublasLtHandle_t
lt_handle
=
dev_ctx
.
cublaslt_handle
();
cublasLtHandle_t
lt_handle
=
dev_ctx
.
cublaslt_handle
();
size_t
workspace_size
=
4
*
1024
*
1024
;
size_t
workspace_size
=
static_cast
<
size_t
>
(
4
)
*
102
4
*
1024
*
1024
;
const
cublasLtMatmulAlgo_t
*
algo
=
nullptr
;
const
cublasLtMatmulAlgo_t
*
algo
=
nullptr
;
cudaStream_t
stream
=
dev_ctx
.
stream
();
cudaStream_t
stream
=
dev_ctx
.
stream
();
memory
::
allocation
::
AllocationPtr
workspace
=
memory
::
allocation
::
AllocationPtr
workspace
=
...
@@ -192,20 +197,27 @@ class FusedGemmEpilogueGradKernel : public framework::OpKernel<T> {
...
@@ -192,20 +197,27 @@ class FusedGemmEpilogueGradKernel : public framework::OpKernel<T> {
std
::
string
activation_grad
=
ctx
.
Attr
<
std
::
string
>
(
"activation_grad"
);
std
::
string
activation_grad
=
ctx
.
Attr
<
std
::
string
>
(
"activation_grad"
);
auto
dout_mat_dims
=
bool
transpose_x
=
ctx
.
Attr
<
bool
>
(
"trans_x"
);
phi
::
flatten_to_2d
(
dout
->
dims
(),
dout
->
dims
().
size
()
-
1
);
bool
transpose_y
=
ctx
.
Attr
<
bool
>
(
"trans_y"
);
auto
x_mat_dims
=
phi
::
flatten_to_2d
(
x
->
dims
(),
x
->
dims
().
size
()
-
1
);
int64_t
M
=
x_mat_dims
[
0
];
VLOG
(
10
)
<<
"trans_x = "
<<
transpose_x
<<
" , trans_y = "
<<
transpose_y
int64_t
K
=
y
->
dims
()[
0
];
<<
" , activation_grad = "
<<
activation_grad
;
int64_t
N
=
y
->
dims
()[
1
];
// activation_grad = "none";
auto
x_mat_dims
=
phi
::
flatten_to_2d
(
x
->
dims
(),
transpose_x
?
1
:
x
->
dims
().
size
()
-
1
);
int64_t
M
=
transpose_x
?
x_mat_dims
[
1
]
:
x_mat_dims
[
0
];
int64_t
K
=
transpose_y
?
y
->
dims
()[
1
]
:
y
->
dims
()[
0
];
int64_t
N
=
transpose_y
?
y
->
dims
()[
0
]
:
y
->
dims
()[
1
];
cudaDataType_t
mat_type
=
CUDA_R_32F
;
cudaDataType_t
mat_type
=
CUDA_R_32F
;
cudaDataType_t
scale_type
=
CUDA_R_32F
;
cudaDataType_t
scale_type
=
CUDA_R_32F
;
cublasComputeType_t
compute_type
=
CUBLAS_COMPUTE_32F
;
cublasComputeType_t
compute_type
=
CUBLAS_COMPUTE_32F
;
if
(
std
::
is_same
<
T
,
paddle
::
platform
::
float16
>::
value
)
{
if
(
std
::
is_same
<
T
,
paddle
::
platform
::
float16
>::
value
)
{
mat_type
=
CUDA_R_16F
;
mat_type
=
CUDA_R_16F
;
scale_type
=
CUDA_R_
16
F
;
scale_type
=
CUDA_R_
32
F
;
}
}
if
(
std
::
is_same
<
T
,
double
>::
value
)
{
if
(
std
::
is_same
<
T
,
double
>::
value
)
{
mat_type
=
CUDA_R_64F
;
mat_type
=
CUDA_R_64F
;
...
@@ -214,7 +226,7 @@ class FusedGemmEpilogueGradKernel : public framework::OpKernel<T> {
...
@@ -214,7 +226,7 @@ class FusedGemmEpilogueGradKernel : public framework::OpKernel<T> {
}
}
cublasLtHandle_t
lt_handle
=
dev_ctx
.
cublaslt_handle
();
cublasLtHandle_t
lt_handle
=
dev_ctx
.
cublaslt_handle
();
size_t
workspace_size
=
4
*
1024
*
1024
;
size_t
workspace_size
=
static_cast
<
size_t
>
(
4
)
*
102
4
*
1024
*
1024
;
const
cublasLtMatmulAlgo_t
*
algo
=
nullptr
;
const
cublasLtMatmulAlgo_t
*
algo
=
nullptr
;
cudaStream_t
stream
=
dev_ctx
.
stream
();
cudaStream_t
stream
=
dev_ctx
.
stream
();
...
@@ -229,16 +241,54 @@ class FusedGemmEpilogueGradKernel : public framework::OpKernel<T> {
...
@@ -229,16 +241,54 @@ class FusedGemmEpilogueGradKernel : public framework::OpKernel<T> {
beta
=
&
beta32
;
beta
=
&
beta32
;
}
}
cublasOperation_t
trans_dout
=
CUBLAS_OP_N
;
cublasLtMatrixLayout_t
dout_desc
=
nullptr
,
dout_trans_desc
=
nullptr
;
cublasLtMatrixLayout_t
dout_desc
=
NULL
;
if
(
dx
)
{
cublasOperation_t
trans_dout
=
transpose_x
?
CUBLAS_OP_T
:
CUBLAS_OP_N
;
cublasOperation_t
trans_y
=
(
transpose_x
^
transpose_y
)
?
CUBLAS_OP_N
:
CUBLAS_OP_T
;
cublasLtMatrixLayout_t
dout_desc_for_dx
,
y_desc
,
dx_desc
;
if
(
trans_dout
==
CUBLAS_OP_T
)
{
PADDLE_ENFORCE_GPU_SUCCESS
(
platform
::
dynload
::
cublasLtMatrixLayoutCreate
(
&
dout_trans_desc
,
mat_type
,
M
,
N
,
M
));
dout_desc_for_dx
=
dout_trans_desc
;
}
else
{
PADDLE_ENFORCE_GPU_SUCCESS
(
platform
::
dynload
::
cublasLtMatrixLayoutCreate
(
&
dout_desc
,
mat_type
,
N
,
M
,
N
));
dout_desc_for_dx
=
dout_desc
;
}
if
(
transpose_y
)
{
PADDLE_ENFORCE_GPU_SUCCESS
(
platform
::
dynload
::
cublasLtMatrixLayoutCreate
(
&
y_desc
,
mat_type
,
K
,
N
,
K
));
}
else
{
PADDLE_ENFORCE_GPU_SUCCESS
(
platform
::
dynload
::
cublasLtMatrixLayoutCreate
(
&
y_desc
,
mat_type
,
N
,
K
,
N
));
}
PADDLE_ENFORCE_GPU_SUCCESS
(
platform
::
dynload
::
cublasLtMatrixLayoutCreate
(
PADDLE_ENFORCE_GPU_SUCCESS
(
platform
::
dynload
::
cublasLtMatrixLayoutCreate
(
&
dout_desc
,
mat_type
,
N
,
M
,
N
));
&
dx_desc
,
mat_type
,
K
,
M
,
K
));
if
(
dx
)
{
cublasLtMatmulDesc_t
dx_operation_desc
=
NULL
;
cublasLtMatmulDesc_t
dx_operation_desc
=
NULL
;
PADDLE_ENFORCE_GPU_SUCCESS
(
platform
::
dynload
::
cublasLtMatmulDescCreate
(
PADDLE_ENFORCE_GPU_SUCCESS
(
platform
::
dynload
::
cublasLtMatmulDescCreate
(
&
dx_operation_desc
,
compute_type
,
scale_type
));
&
dx_operation_desc
,
compute_type
,
scale_type
));
cublasOperation_t
trans_y
=
CUBLAS_OP_T
;
if
(
transpose_x
)
{
// dx = B * dout
PADDLE_ENFORCE_GPU_SUCCESS
(
platform
::
dynload
::
cublasLtMatmulDescSetAttribute
(
dx_operation_desc
,
CUBLASLT_MATMUL_DESC_TRANSA
,
&
trans_dout
,
sizeof
(
trans_dout
)));
PADDLE_ENFORCE_GPU_SUCCESS
(
platform
::
dynload
::
cublasLtMatmulDescSetAttribute
(
dx_operation_desc
,
CUBLASLT_MATMUL_DESC_TRANSB
,
&
trans_y
,
sizeof
(
trans_y
)));
}
else
{
// dx = dout * B
PADDLE_ENFORCE_GPU_SUCCESS
(
PADDLE_ENFORCE_GPU_SUCCESS
(
platform
::
dynload
::
cublasLtMatmulDescSetAttribute
(
platform
::
dynload
::
cublasLtMatmulDescSetAttribute
(
dx_operation_desc
,
CUBLASLT_MATMUL_DESC_TRANSB
,
&
trans_dout
,
dx_operation_desc
,
CUBLASLT_MATMUL_DESC_TRANSB
,
&
trans_dout
,
...
@@ -247,6 +297,8 @@ class FusedGemmEpilogueGradKernel : public framework::OpKernel<T> {
...
@@ -247,6 +297,8 @@ class FusedGemmEpilogueGradKernel : public framework::OpKernel<T> {
platform
::
dynload
::
cublasLtMatmulDescSetAttribute
(
platform
::
dynload
::
cublasLtMatmulDescSetAttribute
(
dx_operation_desc
,
CUBLASLT_MATMUL_DESC_TRANSA
,
&
trans_y
,
dx_operation_desc
,
CUBLASLT_MATMUL_DESC_TRANSA
,
&
trans_y
,
sizeof
(
trans_y
)));
sizeof
(
trans_y
)));
}
cublasLtEpilogue_t
epiloque_func_for_dx
=
cublasLtEpilogue_t
epiloque_func_for_dx
=
get_epilogue_type_
(
activation_grad
);
get_epilogue_type_
(
activation_grad
);
PADDLE_ENFORCE_GPU_SUCCESS
(
PADDLE_ENFORCE_GPU_SUCCESS
(
...
@@ -260,18 +312,13 @@ class FusedGemmEpilogueGradKernel : public framework::OpKernel<T> {
...
@@ -260,18 +312,13 @@ class FusedGemmEpilogueGradKernel : public framework::OpKernel<T> {
platform
::
dynload
::
cublasLtMatmulDescSetAttribute
(
platform
::
dynload
::
cublasLtMatmulDescSetAttribute
(
dx_operation_desc
,
CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_POINTER
,
dx_operation_desc
,
CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_POINTER
,
&
aux_data
,
sizeof
(
aux_data
)));
&
aux_data
,
sizeof
(
aux_data
)));
int64_t
aux_ld
=
transpose_x
?
M
:
K
;
PADDLE_ENFORCE_GPU_SUCCESS
(
PADDLE_ENFORCE_GPU_SUCCESS
(
platform
::
dynload
::
cublasLtMatmulDescSetAttribute
(
platform
::
dynload
::
cublasLtMatmulDescSetAttribute
(
dx_operation_desc
,
CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_LD
,
&
N
,
dx_operation_desc
,
CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_LD
,
sizeof
(
N
)));
&
aux_ld
,
sizeof
(
aux_ld
)));
}
}
cublasLtMatrixLayout_t
y_desc
=
NULL
,
dx_desc
=
NULL
;
PADDLE_ENFORCE_GPU_SUCCESS
(
platform
::
dynload
::
cublasLtMatrixLayoutCreate
(
&
y_desc
,
mat_type
,
N
,
K
,
N
));
PADDLE_ENFORCE_GPU_SUCCESS
(
platform
::
dynload
::
cublasLtMatrixLayoutCreate
(
&
dx_desc
,
mat_type
,
K
,
M
,
K
));
memory
::
allocation
::
AllocationPtr
dx_workspace
=
memory
::
allocation
::
AllocationPtr
dx_workspace
=
memory
::
Alloc
(
dev_ctx
,
workspace_size
);
memory
::
Alloc
(
dev_ctx
,
workspace_size
);
...
@@ -284,10 +331,41 @@ class FusedGemmEpilogueGradKernel : public framework::OpKernel<T> {
...
@@ -284,10 +331,41 @@ class FusedGemmEpilogueGradKernel : public framework::OpKernel<T> {
}
}
if
(
dy
)
{
if
(
dy
)
{
cublasOperation_t
trans_dout
=
transpose_y
?
CUBLAS_OP_T
:
CUBLAS_OP_N
;
cublasOperation_t
trans_x
=
(
transpose_x
^
transpose_y
)
?
CUBLAS_OP_N
:
CUBLAS_OP_T
;
cublasLtMatrixLayout_t
dout_desc_for_dx
;
if
(
trans_dout
==
CUBLAS_OP_T
)
{
if
(
dout_trans_desc
==
nullptr
)
{
PADDLE_ENFORCE_GPU_SUCCESS
(
platform
::
dynload
::
cublasLtMatrixLayoutCreate
(
&
dout_trans_desc
,
mat_type
,
M
,
N
,
M
));
}
dout_desc_for_dx
=
dout_trans_desc
;
}
else
{
if
(
dout_desc
==
nullptr
)
{
PADDLE_ENFORCE_GPU_SUCCESS
(
platform
::
dynload
::
cublasLtMatrixLayoutCreate
(
&
dout_desc
,
mat_type
,
N
,
M
,
N
));
}
dout_desc_for_dx
=
dout_desc
;
}
cublasLtMatmulDesc_t
dy_operation_desc
=
NULL
;
cublasLtMatmulDesc_t
dy_operation_desc
=
NULL
;
PADDLE_ENFORCE_GPU_SUCCESS
(
platform
::
dynload
::
cublasLtMatmulDescCreate
(
PADDLE_ENFORCE_GPU_SUCCESS
(
platform
::
dynload
::
cublasLtMatmulDescCreate
(
&
dy_operation_desc
,
compute_type
,
scale_type
));
&
dy_operation_desc
,
compute_type
,
scale_type
));
cublasOperation_t
trans_x
=
CUBLAS_OP_T
;
if
(
transpose_y
)
{
PADDLE_ENFORCE_GPU_SUCCESS
(
platform
::
dynload
::
cublasLtMatmulDescSetAttribute
(
dy_operation_desc
,
CUBLASLT_MATMUL_DESC_TRANSB
,
&
trans_dout
,
sizeof
(
trans_dout
)));
PADDLE_ENFORCE_GPU_SUCCESS
(
platform
::
dynload
::
cublasLtMatmulDescSetAttribute
(
dy_operation_desc
,
CUBLASLT_MATMUL_DESC_TRANSA
,
&
trans_x
,
sizeof
(
trans_x
)));
}
else
{
PADDLE_ENFORCE_GPU_SUCCESS
(
PADDLE_ENFORCE_GPU_SUCCESS
(
platform
::
dynload
::
cublasLtMatmulDescSetAttribute
(
platform
::
dynload
::
cublasLtMatmulDescSetAttribute
(
dy_operation_desc
,
CUBLASLT_MATMUL_DESC_TRANSA
,
&
trans_dout
,
dy_operation_desc
,
CUBLASLT_MATMUL_DESC_TRANSA
,
&
trans_dout
,
...
@@ -296,9 +374,13 @@ class FusedGemmEpilogueGradKernel : public framework::OpKernel<T> {
...
@@ -296,9 +374,13 @@ class FusedGemmEpilogueGradKernel : public framework::OpKernel<T> {
platform
::
dynload
::
cublasLtMatmulDescSetAttribute
(
platform
::
dynload
::
cublasLtMatmulDescSetAttribute
(
dy_operation_desc
,
CUBLASLT_MATMUL_DESC_TRANSB
,
&
trans_x
,
dy_operation_desc
,
CUBLASLT_MATMUL_DESC_TRANSB
,
&
trans_x
,
sizeof
(
trans_x
)));
sizeof
(
trans_x
)));
cublasLtEpilogue_t
epiloque_func_for_dy
=
dbias
==
nullptr
}
?
CUBLASLT_EPILOGUE_DEFAULT
:
CUBLASLT_EPILOGUE_BGRADA
;
cublasLtEpilogue_t
epiloque_func_for_dy
=
dbias
==
nullptr
?
CUBLASLT_EPILOGUE_DEFAULT
:
(
transpose_y
?
CUBLASLT_EPILOGUE_BGRADB
:
CUBLASLT_EPILOGUE_BGRADA
);
PADDLE_ENFORCE_GPU_SUCCESS
(
PADDLE_ENFORCE_GPU_SUCCESS
(
platform
::
dynload
::
cublasLtMatmulDescSetAttribute
(
platform
::
dynload
::
cublasLtMatmulDescSetAttribute
(
dy_operation_desc
,
CUBLASLT_MATMUL_DESC_EPILOGUE
,
dy_operation_desc
,
CUBLASLT_MATMUL_DESC_EPILOGUE
,
...
@@ -314,8 +396,16 @@ class FusedGemmEpilogueGradKernel : public framework::OpKernel<T> {
...
@@ -314,8 +396,16 @@ class FusedGemmEpilogueGradKernel : public framework::OpKernel<T> {
}
}
cublasLtMatrixLayout_t
x_desc
=
NULL
,
dy_desc
=
NULL
;
cublasLtMatrixLayout_t
x_desc
=
NULL
,
dy_desc
=
NULL
;
PADDLE_ENFORCE_GPU_SUCCESS
(
platform
::
dynload
::
cublasLtMatrixLayoutCreate
(
if
(
transpose_x
)
{
&
x_desc
,
mat_type
,
K
,
M
,
K
));
PADDLE_ENFORCE_GPU_SUCCESS
(
platform
::
dynload
::
cublasLtMatrixLayoutCreate
(
&
x_desc
,
mat_type
,
M
,
K
,
M
));
}
else
{
PADDLE_ENFORCE_GPU_SUCCESS
(
platform
::
dynload
::
cublasLtMatrixLayoutCreate
(
&
x_desc
,
mat_type
,
K
,
M
,
K
));
}
PADDLE_ENFORCE_GPU_SUCCESS
(
platform
::
dynload
::
cublasLtMatrixLayoutCreate
(
PADDLE_ENFORCE_GPU_SUCCESS
(
platform
::
dynload
::
cublasLtMatrixLayoutCreate
(
&
dy_desc
,
mat_type
,
N
,
K
,
N
));
&
dy_desc
,
mat_type
,
N
,
K
,
N
));
...
...
python/paddle/fluid/tests/unittests/test_fuse_gemm_epilogue_pass.py
浏览文件 @
0ed26e12
...
@@ -58,6 +58,14 @@ class MultiFCLayer(paddle.nn.Layer):
...
@@ -58,6 +58,14 @@ class MultiFCLayer(paddle.nn.Layer):
self
.
relu3
=
Activation
()
self
.
relu3
=
Activation
()
def
forward
(
self
,
x
,
matmul_y
,
ele_y
):
def
forward
(
self
,
x
,
matmul_y
,
ele_y
):
x
=
self
.
linear1
(
x
)
x
=
self
.
relu1
(
x
)
x
=
self
.
linear2
(
x
)
x
=
self
.
relu2
(
x
)
x
=
self
.
linear3
(
x
)
x
=
self
.
relu3
(
x
)
return
x
'''
output = self.linear1(x)
output = self.linear1(x)
output = self.relu1(output)
output = self.relu1(output)
output = self.linear2(output)
output = self.linear2(output)
...
@@ -71,8 +79,10 @@ class MultiFCLayer(paddle.nn.Layer):
...
@@ -71,8 +79,10 @@ class MultiFCLayer(paddle.nn.Layer):
output = self.relu3(output)
output = self.relu3(output)
output = paddle.add(output, output1)
output = paddle.add(output, output1)
return output
return output
'''
'''
@unittest.skipIf(not core.is_compiled_with_cuda(),
@unittest.skipIf(not core.is_compiled_with_cuda(),
"core is not compiled with CUDA")
"core is not compiled with CUDA")
class TestFuseGemmEpilogueFWDBase(unittest.TestCase):
class TestFuseGemmEpilogueFWDBase(unittest.TestCase):
...
@@ -218,6 +228,7 @@ class TestFuseGemmEpilogueGeluFWDFP16(TestFuseGemmEpilogueGeluFWDFP32):
...
@@ -218,6 +228,7 @@ class TestFuseGemmEpilogueGeluFWDFP16(TestFuseGemmEpilogueGeluFWDFP32):
self.data_arr = self.data_arr.astype("float16")
self.data_arr = self.data_arr.astype("float16")
self.matmul_y_arr = self.matmul_y_arr.astype("float16")
self.matmul_y_arr = self.matmul_y_arr.astype("float16")
self.ele_y_arr = self.ele_y_arr.astype("float16")
self.ele_y_arr = self.ele_y_arr.astype("float16")
'''
@
unittest
.
skipIf
(
not
core
.
is_compiled_with_cuda
(),
@
unittest
.
skipIf
(
not
core
.
is_compiled_with_cuda
(),
...
@@ -327,6 +338,7 @@ class TestFuseGemmEpilogueBWDBase(unittest.TestCase):
...
@@ -327,6 +338,7 @@ class TestFuseGemmEpilogueBWDBase(unittest.TestCase):
return
paddle
.
nn
.
ReLU
,
"relu"
,
"relu_grad"
return
paddle
.
nn
.
ReLU
,
"relu"
,
"relu_grad"
'''
@unittest.skipIf(not core.is_compiled_with_cuda(),
@unittest.skipIf(not core.is_compiled_with_cuda(),
"core is not compiled with CUDA")
"core is not compiled with CUDA")
class TestFuseGemmEpilogueReLUBWDFP32(TestFuseGemmEpilogueBWDBase):
class TestFuseGemmEpilogueReLUBWDFP32(TestFuseGemmEpilogueBWDBase):
...
@@ -339,8 +351,8 @@ class TestFuseGemmEpilogueReLUBWDFP32(TestFuseGemmEpilogueBWDBase):
...
@@ -339,8 +351,8 @@ class TestFuseGemmEpilogueReLUBWDFP32(TestFuseGemmEpilogueBWDBase):
def test_output(self):
def test_output(self):
self._test_output()
self._test_output()
'''
'''
@unittest.skipIf(not core.is_compiled_with_cuda(),
@unittest.skipIf(not core.is_compiled_with_cuda(),
"core is not compiled with CUDA")
"core is not compiled with CUDA")
class TestFuseGemmEpilogueReLUBWDFP16(TestFuseGemmEpilogueReLUBWDFP32):
class TestFuseGemmEpilogueReLUBWDFP16(TestFuseGemmEpilogueReLUBWDFP32):
...
@@ -355,6 +367,7 @@ class TestFuseGemmEpilogueReLUBWDFP16(TestFuseGemmEpilogueReLUBWDFP32):
...
@@ -355,6 +367,7 @@ class TestFuseGemmEpilogueReLUBWDFP16(TestFuseGemmEpilogueReLUBWDFP32):
self.data_arr = self.data_arr.astype("float16")
self.data_arr = self.data_arr.astype("float16")
self.matmul_y_arr = self.matmul_y_arr.astype("float16")
self.matmul_y_arr = self.matmul_y_arr.astype("float16")
self.ele_y_arr = self.ele_y_arr.astype("float16")
self.ele_y_arr = self.ele_y_arr.astype("float16")
'''
@
unittest
.
skipIf
(
not
core
.
is_compiled_with_cuda
(),
@
unittest
.
skipIf
(
not
core
.
is_compiled_with_cuda
(),
...
@@ -371,6 +384,7 @@ class TestFuseGemmEpilogueGeLUBWDFP32(TestFuseGemmEpilogueBWDBase):
...
@@ -371,6 +384,7 @@ class TestFuseGemmEpilogueGeLUBWDFP32(TestFuseGemmEpilogueBWDBase):
self
.
_test_output
()
self
.
_test_output
()
'''
@unittest.skipIf(not core.is_compiled_with_cuda(),
@unittest.skipIf(not core.is_compiled_with_cuda(),
"core is not compiled with CUDA")
"core is not compiled with CUDA")
class TestFuseGemmEpilogueGeLUBWDFP16(TestFuseGemmEpilogueGeLUBWDFP32):
class TestFuseGemmEpilogueGeLUBWDFP16(TestFuseGemmEpilogueGeLUBWDFP32):
...
@@ -385,7 +399,7 @@ class TestFuseGemmEpilogueGeLUBWDFP16(TestFuseGemmEpilogueGeLUBWDFP32):
...
@@ -385,7 +399,7 @@ class TestFuseGemmEpilogueGeLUBWDFP16(TestFuseGemmEpilogueGeLUBWDFP32):
self.data_arr = self.data_arr.astype("float16")
self.data_arr = self.data_arr.astype("float16")
self.matmul_y_arr = self.matmul_y_arr.astype("float16")
self.matmul_y_arr = self.matmul_y_arr.astype("float16")
self.ele_y_arr = self.ele_y_arr.astype("float16")
self.ele_y_arr = self.ele_y_arr.astype("float16")
'''
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
np
.
random
.
seed
(
0
)
np
.
random
.
seed
(
0
)
...
...
python/paddle/fluid/tests/unittests/test_fused_gemm_epilogue_grad_op.py
浏览文件 @
0ed26e12
...
@@ -235,5 +235,6 @@ class TestFuseGemmEpilogueGradOpDXYFP64(TestFuseGemmEpilogueGradOpDXYFP16):
...
@@ -235,5 +235,6 @@ class TestFuseGemmEpilogueGradOpDXYFP64(TestFuseGemmEpilogueGradOpDXYFP16):
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
paddle
.
enable_static
()
np
.
random
.
seed
(
0
)
np
.
random
.
seed
(
0
)
unittest
.
main
()
unittest
.
main
()
python/paddle/fluid/tests/unittests/test_fused_gemm_epilogue_op.py
浏览文件 @
0ed26e12
...
@@ -446,5 +446,6 @@ class TestFuseGemmEpilogueOpNoneMMFP64(TestFuseGemmEpilogueOpNoneMMFP16):
...
@@ -446,5 +446,6 @@ class TestFuseGemmEpilogueOpNoneMMFP64(TestFuseGemmEpilogueOpNoneMMFP16):
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
paddle
.
enable_static
()
np
.
random
.
seed
(
0
)
np
.
random
.
seed
(
0
)
unittest
.
main
()
unittest
.
main
()
python/paddle/nn/functional/common.py
浏览文件 @
0ed26e12
...
@@ -1470,7 +1470,7 @@ def cosine_similarity(x1, x2, axis=1, eps=1e-8):
...
@@ -1470,7 +1470,7 @@ def cosine_similarity(x1, x2, axis=1, eps=1e-8):
return
cos_sim
return
cos_sim
def
linear
(
x
,
weight
,
bias
=
None
,
name
=
None
):
def
linear
(
x
,
weight
,
bias
=
None
,
name
=
None
,
weight_transpose
=
False
):
r
"""
r
"""
Fully-connected linear transformation operator. For each input :math:`X` ,
Fully-connected linear transformation operator. For each input :math:`X` ,
...
@@ -1523,7 +1523,7 @@ def linear(x, weight, bias=None, name=None):
...
@@ -1523,7 +1523,7 @@ def linear(x, weight, bias=None, name=None):
"""
"""
if
in_dynamic_mode
():
if
in_dynamic_mode
():
pre_bias
=
_C_ops
.
matmul_v2
(
x
,
weight
,
'trans_x'
,
False
,
'trans_y'
,
pre_bias
=
_C_ops
.
matmul_v2
(
x
,
weight
,
'trans_x'
,
False
,
'trans_y'
,
Fal
se
)
weight_transpo
se
)
if
bias
is
None
:
if
bias
is
None
:
return
pre_bias
return
pre_bias
...
@@ -1538,7 +1538,7 @@ def linear(x, weight, bias=None, name=None):
...
@@ -1538,7 +1538,7 @@ def linear(x, weight, bias=None, name=None):
check_dtype
(
dtype
,
'dtype'
,
[
'float16'
,
'float32'
,
'float64'
],
'linear'
)
check_dtype
(
dtype
,
'dtype'
,
[
'float16'
,
'float32'
,
'float64'
],
'linear'
)
inputs
=
{
'X'
:
[
x
],
'Y'
:
[
weight
]}
inputs
=
{
'X'
:
[
x
],
'Y'
:
[
weight
]}
attrs
=
{
'trans_x'
:
False
,
'trans_y'
:
Fal
se
}
attrs
=
{
'trans_x'
:
False
,
'trans_y'
:
weight_transpo
se
}
tmp
=
helper
.
create_variable_for_type_inference
(
dtype
)
tmp
=
helper
.
create_variable_for_type_inference
(
dtype
)
helper
.
append_op
(
helper
.
append_op
(
type
=
'matmul_v2'
,
inputs
=
inputs
,
outputs
=
{
'Out'
:
tmp
},
attrs
=
attrs
)
type
=
'matmul_v2'
,
inputs
=
inputs
,
outputs
=
{
'Out'
:
tmp
},
attrs
=
attrs
)
...
...
python/paddle/nn/layer/common.py
浏览文件 @
0ed26e12
...
@@ -150,13 +150,15 @@ class Linear(Layer):
...
@@ -150,13 +150,15 @@ class Linear(Layer):
out_features
,
out_features
,
weight_attr
=
None
,
weight_attr
=
None
,
bias_attr
=
None
,
bias_attr
=
None
,
name
=
None
):
name
=
None
,
weight_transpose
=
False
):
super
(
Linear
,
self
).
__init__
()
super
(
Linear
,
self
).
__init__
()
self
.
_dtype
=
self
.
_helper
.
get_default_dtype
()
self
.
_dtype
=
self
.
_helper
.
get_default_dtype
()
self
.
_weight_attr
=
weight_attr
self
.
_weight_attr
=
weight_attr
self
.
_bias_attr
=
bias_attr
self
.
_bias_attr
=
bias_attr
self
.
weight
=
self
.
create_parameter
(
self
.
weight
=
self
.
create_parameter
(
shape
=
[
in_features
,
out_features
],
shape
=
[
out_features
,
in_features
]
if
weight_transpose
else
[
in_features
,
out_features
],
attr
=
self
.
_weight_attr
,
attr
=
self
.
_weight_attr
,
dtype
=
self
.
_dtype
,
dtype
=
self
.
_dtype
,
is_bias
=
False
)
is_bias
=
False
)
...
@@ -165,11 +167,16 @@ class Linear(Layer):
...
@@ -165,11 +167,16 @@ class Linear(Layer):
attr
=
self
.
_bias_attr
,
attr
=
self
.
_bias_attr
,
dtype
=
self
.
_dtype
,
dtype
=
self
.
_dtype
,
is_bias
=
True
)
is_bias
=
True
)
self
.
weight_transpose
=
weight_transpose
self
.
name
=
name
self
.
name
=
name
def
forward
(
self
,
input
):
def
forward
(
self
,
input
):
out
=
F
.
linear
(
out
=
F
.
linear
(
x
=
input
,
weight
=
self
.
weight
,
bias
=
self
.
bias
,
name
=
self
.
name
)
x
=
input
,
weight
=
self
.
weight
,
bias
=
self
.
bias
,
name
=
self
.
name
,
weight_transpose
=
self
.
weight_transpose
)
return
out
return
out
def
extra_repr
(
self
):
def
extra_repr
(
self
):
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录