提交 0ed26e12 编写于 作者: R root

support weight transpose

上级 60b86b2f
...@@ -118,6 +118,7 @@ message BuildStrategy { ...@@ -118,6 +118,7 @@ message BuildStrategy {
optional bool fix_op_run_order = 13 [ default = false ]; optional bool fix_op_run_order = 13 [ default = false ];
optional bool allow_cuda_graph_capture = 14 [ default = false ]; optional bool allow_cuda_graph_capture = 14 [ default = false ];
optional int32 reduce_strategy = 15 [ default = 0 ]; optional int32 reduce_strategy = 15 [ default = 0 ];
optional bool fuse_gemm_epilogue = 16 [ default = false ];
} }
message ExecutionStrategy { message ExecutionStrategy {
......
...@@ -18,18 +18,28 @@ ...@@ -18,18 +18,28 @@
#include "paddle/fluid/framework/operator.h" #include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/platform/enforce.h" #include "paddle/fluid/platform/enforce.h"
PADDLE_DEFINE_EXPORTED_bool(enable_gemm_fwd_fusion, true, "");
namespace paddle { namespace paddle {
namespace framework { namespace framework {
namespace ir { namespace ir {
static void GetTransposeAttrsFromOp(const OpDesc &op, bool *trans_x,
bool *trans_y) {
*trans_x = BOOST_GET_CONST(bool, op.GetAttr("trans_x"));
*trans_y = BOOST_GET_CONST(bool, op.GetAttr("trans_y"));
}
void FuseGemmEpiloguePass::ApplyImpl(ir::Graph *graph) const { void FuseGemmEpiloguePass::ApplyImpl(ir::Graph *graph) const {
EpiloguePassActivationCache cache; EpiloguePassActivationCache cache;
if (FLAGS_enable_gemm_fwd_fusion) {
graph = FuseLinearActFwd(graph, {"relu", "gelu"}, false, false, &cache); graph = FuseLinearActFwd(graph, {"relu", "gelu"}, false, false, &cache);
graph = FuseLinearActFwd(graph, {"relu"}, true, true, &cache); graph = FuseLinearActFwd(graph, {"relu"}, true, true, &cache);
graph = FuseLinearActFwd(graph, {"gelu"}, true, false, &cache); graph = FuseLinearActFwd(graph, {"gelu"}, true, false, &cache);
graph = FuseLinearFwd(graph, false); graph = FuseLinearFwd(graph, false);
graph = FuseLinearFwd(graph, true); graph = FuseLinearFwd(graph, true);
}
graph = FuseLinearActBwd(graph, {"relu_grad"}, true, &cache); graph = FuseLinearActBwd(graph, {"relu_grad"}, true, &cache);
graph = FuseLinearActBwd(graph, {"gelu_grad"}, false, &cache); graph = FuseLinearActBwd(graph, {"gelu_grad"}, false, &cache);
graph = FuseLinearBwd(graph, false); graph = FuseLinearBwd(graph, false);
...@@ -75,6 +85,9 @@ ir::Graph *FuseGemmEpiloguePass::FuseLinearFwd(ir::Graph *graph, ...@@ -75,6 +85,9 @@ ir::Graph *FuseGemmEpiloguePass::FuseLinearFwd(ir::Graph *graph,
if (!IsGemmFromLinear_(matmul_x_shape, matmul_w_shape, matmul_op_desc)) if (!IsGemmFromLinear_(matmul_x_shape, matmul_w_shape, matmul_op_desc))
return; return;
bool trans_x, trans_y;
GetTransposeAttrsFromOp(*matmul_op_desc, &trans_x, &trans_y);
OpDesc fused_gemm_epilogue_op_desc(matmul_op->Op()->Block()); OpDesc fused_gemm_epilogue_op_desc(matmul_op->Op()->Block());
std::string activation = "none"; std::string activation = "none";
fused_gemm_epilogue_op_desc.SetType("fused_gemm_epilogue"); fused_gemm_epilogue_op_desc.SetType("fused_gemm_epilogue");
...@@ -85,6 +98,8 @@ ir::Graph *FuseGemmEpiloguePass::FuseLinearFwd(ir::Graph *graph, ...@@ -85,6 +98,8 @@ ir::Graph *FuseGemmEpiloguePass::FuseLinearFwd(ir::Graph *graph,
fused_gemm_epilogue_op_desc.SetAttr("activation", activation); fused_gemm_epilogue_op_desc.SetAttr("activation", activation);
fused_gemm_epilogue_op_desc.SetAttr("op_role", fused_gemm_epilogue_op_desc.SetAttr("op_role",
matmul_op_desc->GetAttr("op_role")); matmul_op_desc->GetAttr("op_role"));
fused_gemm_epilogue_op_desc.SetAttr("trans_x", trans_x);
fused_gemm_epilogue_op_desc.SetAttr("trans_y", trans_y);
auto gemm_epilogue_node = g->CreateOpNode(&fused_gemm_epilogue_op_desc); auto gemm_epilogue_node = g->CreateOpNode(&fused_gemm_epilogue_op_desc);
IR_NODE_LINK_TO(subgraph.at(x), gemm_epilogue_node); IR_NODE_LINK_TO(subgraph.at(x), gemm_epilogue_node);
...@@ -154,6 +169,9 @@ ir::Graph *FuseGemmEpiloguePass::FuseLinearActFwd( ...@@ -154,6 +169,9 @@ ir::Graph *FuseGemmEpiloguePass::FuseLinearActFwd(
auto activation = act_op->Op()->Type(); auto activation = act_op->Op()->Type();
bool trans_x, trans_y;
GetTransposeAttrsFromOp(*matmul_op_desc, &trans_x, &trans_y);
OpDesc fused_gemm_epilogue_op_desc(matmul_op->Op()->Block()); OpDesc fused_gemm_epilogue_op_desc(matmul_op->Op()->Block());
fused_gemm_epilogue_op_desc.SetType("fused_gemm_epilogue"); fused_gemm_epilogue_op_desc.SetType("fused_gemm_epilogue");
fused_gemm_epilogue_op_desc.SetInput("X", {subgraph.at(x)->Name()}); fused_gemm_epilogue_op_desc.SetInput("X", {subgraph.at(x)->Name()});
...@@ -163,6 +181,8 @@ ir::Graph *FuseGemmEpiloguePass::FuseLinearActFwd( ...@@ -163,6 +181,8 @@ ir::Graph *FuseGemmEpiloguePass::FuseLinearActFwd(
fused_gemm_epilogue_op_desc.SetAttr("activation", activation); fused_gemm_epilogue_op_desc.SetAttr("activation", activation);
fused_gemm_epilogue_op_desc.SetAttr("op_role", fused_gemm_epilogue_op_desc.SetAttr("op_role",
matmul_op_desc->GetAttr("op_role")); matmul_op_desc->GetAttr("op_role"));
fused_gemm_epilogue_op_desc.SetAttr("trans_x", trans_x);
fused_gemm_epilogue_op_desc.SetAttr("trans_y", trans_y);
auto gemm_epilogue_node = g->CreateOpNode(&fused_gemm_epilogue_op_desc); auto gemm_epilogue_node = g->CreateOpNode(&fused_gemm_epilogue_op_desc);
...@@ -274,6 +294,9 @@ ir::Graph *FuseGemmEpiloguePass::FuseLinearBwd(ir::Graph *graph, ...@@ -274,6 +294,9 @@ ir::Graph *FuseGemmEpiloguePass::FuseLinearBwd(ir::Graph *graph,
matmul_grad_op_desc)) matmul_grad_op_desc))
return; return;
bool trans_x, trans_y;
GetTransposeAttrsFromOp(*matmul_grad_op_desc, &trans_x, &trans_y);
OpDesc fused_gemm_epilogue_grad_op_desc(ele_add_grad_op->Op()->Block()); OpDesc fused_gemm_epilogue_grad_op_desc(ele_add_grad_op->Op()->Block());
std::string activation_grad = "none"; std::string activation_grad = "none";
fused_gemm_epilogue_grad_op_desc.SetType("fused_gemm_epilogue_grad"); fused_gemm_epilogue_grad_op_desc.SetType("fused_gemm_epilogue_grad");
...@@ -292,6 +315,8 @@ ir::Graph *FuseGemmEpiloguePass::FuseLinearBwd(ir::Graph *graph, ...@@ -292,6 +315,8 @@ ir::Graph *FuseGemmEpiloguePass::FuseLinearBwd(ir::Graph *graph,
activation_grad); activation_grad);
fused_gemm_epilogue_grad_op_desc.SetAttr( fused_gemm_epilogue_grad_op_desc.SetAttr(
"op_role", matmul_grad_op_desc->GetAttr("op_role")); "op_role", matmul_grad_op_desc->GetAttr("op_role"));
fused_gemm_epilogue_grad_op_desc.SetAttr("trans_x", trans_x);
fused_gemm_epilogue_grad_op_desc.SetAttr("trans_y", trans_y);
auto gemm_epilogue_grad_node = auto gemm_epilogue_grad_node =
g->CreateOpNode(&fused_gemm_epilogue_grad_op_desc); g->CreateOpNode(&fused_gemm_epilogue_grad_op_desc);
...@@ -394,6 +419,8 @@ ir::Graph *FuseGemmEpiloguePass::FuseLinearActBwd( ...@@ -394,6 +419,8 @@ ir::Graph *FuseGemmEpiloguePass::FuseLinearActBwd(
auto activation_grad = act_grad_op->Op()->Type(); auto activation_grad = act_grad_op->Op()->Type();
bool trans_x, trans_y;
GetTransposeAttrsFromOp(*matmul_grad_op_desc, &trans_x, &trans_y);
OpDesc fused_gemm_epilogue_grad_op_desc(ele_add_grad_op->Op()->Block()); OpDesc fused_gemm_epilogue_grad_op_desc(ele_add_grad_op->Op()->Block());
fused_gemm_epilogue_grad_op_desc.SetType("fused_gemm_epilogue_grad"); fused_gemm_epilogue_grad_op_desc.SetType("fused_gemm_epilogue_grad");
fused_gemm_epilogue_grad_op_desc.SetInput("DOut", fused_gemm_epilogue_grad_op_desc.SetInput("DOut",
...@@ -410,6 +437,8 @@ ir::Graph *FuseGemmEpiloguePass::FuseLinearActBwd( ...@@ -410,6 +437,8 @@ ir::Graph *FuseGemmEpiloguePass::FuseLinearActBwd(
activation_grad); activation_grad);
fused_gemm_epilogue_grad_op_desc.SetAttr( fused_gemm_epilogue_grad_op_desc.SetAttr(
"op_role", matmul_grad_op_desc->GetAttr("op_role")); "op_role", matmul_grad_op_desc->GetAttr("op_role"));
fused_gemm_epilogue_grad_op_desc.SetAttr("trans_x", trans_x);
fused_gemm_epilogue_grad_op_desc.SetAttr("trans_y", trans_y);
auto gemm_epilogue_grad_node = auto gemm_epilogue_grad_node =
g->CreateOpNode(&fused_gemm_epilogue_grad_op_desc); g->CreateOpNode(&fused_gemm_epilogue_grad_op_desc);
...@@ -456,10 +485,6 @@ bool FuseGemmEpiloguePass::IsGemmFromLinear_( ...@@ -456,10 +485,6 @@ bool FuseGemmEpiloguePass::IsGemmFromLinear_(
if (tmp_vec.size() > 0) return false; if (tmp_vec.size() > 0) return false;
} }
} }
if (BOOST_GET_CONST(bool, matmul_v2_op->GetAttr("trans_x")) ||
BOOST_GET_CONST(bool, matmul_v2_op->GetAttr("trans_y")))
return false;
return true; return true;
} }
......
...@@ -208,6 +208,9 @@ class FusedGemmEpilogueGradOp : public framework::OperatorWithKernel { ...@@ -208,6 +208,9 @@ class FusedGemmEpilogueGradOp : public framework::OperatorWithKernel {
auto x_dims = ctx->GetInputDim("X"); auto x_dims = ctx->GetInputDim("X");
auto y_dims = ctx->GetInputDim("Y"); auto y_dims = ctx->GetInputDim("Y");
auto trans_x = ctx->Attrs().Get<bool>("trans_x");
auto trans_y = ctx->Attrs().Get<bool>("trans_y");
PADDLE_ENFORCE_GE( PADDLE_ENFORCE_GE(
dout_dims.size(), 2, dout_dims.size(), 2,
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
...@@ -242,14 +245,14 @@ class FusedGemmEpilogueGradOp : public framework::OperatorWithKernel { ...@@ -242,14 +245,14 @@ class FusedGemmEpilogueGradOp : public framework::OperatorWithKernel {
auto x_mat_dims = phi::flatten_to_2d(x_dims, x_dims.size() - 1); auto x_mat_dims = phi::flatten_to_2d(x_dims, x_dims.size() - 1);
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
dout_mat_dims[1], y_dims[1], dout_mat_dims[1], trans_y ? y_dims[0] : y_dims[1],
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
"The last dimension of DOut should be equal with Y's last" "The last dimension of DOut should be equal with Y's last"
"dimension. But received DOut[-1] = [%d], Y[1] = [%d].", "dimension. But received DOut[-1] = [%d], Y[1] = [%d].",
dout_mat_dims[1], y_dims[1])); dout_mat_dims[1], y_dims[1]));
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
dout_mat_dims[0], x_mat_dims[0], dout_mat_dims[0], trans_x ? x_mat_dims[1] : x_mat_dims[0],
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
"The first dimension of DOut should be equal with X's first" "The first dimension of DOut should be equal with X's first"
"dimension. But received DOut[0] = [%d], Y[0] = [%d].", "dimension. But received DOut[0] = [%d], Y[0] = [%d].",
...@@ -323,6 +326,8 @@ class FusedGemmEpilogueGradOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -323,6 +326,8 @@ class FusedGemmEpilogueGradOpMaker : public framework::OpProtoAndCheckerMaker {
AddOutput("DBias", AddOutput("DBias",
"The output grad tensor to bias of Out = (Act(X) * Y) + bias.") "The output grad tensor to bias of Out = (Act(X) * Y) + bias.")
.AsDispensable(); .AsDispensable();
AddAttr<bool>("trans_x", "").SetDefault(false);
AddAttr<bool>("trans_y", "").SetDefault(false);
AddAttr<std::string>( AddAttr<std::string>(
"activation_grad", "activation_grad",
......
...@@ -40,6 +40,9 @@ class FusedGemmEpilogueKernel : public framework::OpKernel<T> { ...@@ -40,6 +40,9 @@ class FusedGemmEpilogueKernel : public framework::OpKernel<T> {
bool trans_y = ctx.Attr<bool>("trans_y"); bool trans_y = ctx.Attr<bool>("trans_y");
std::string activation = ctx.Attr<std::string>("activation"); std::string activation = ctx.Attr<std::string>("activation");
VLOG(10) << "trans_x = " << trans_x << " , trans_y = " << trans_y
<< " , activation = " << activation;
// activation = "none";
bool enable_auxiliary = reserve_space == nullptr ? false : true; bool enable_auxiliary = reserve_space == nullptr ? false : true;
out->mutable_data<T>(ctx.GetPlace()); out->mutable_data<T>(ctx.GetPlace());
...@@ -56,7 +59,7 @@ class FusedGemmEpilogueKernel : public framework::OpKernel<T> { ...@@ -56,7 +59,7 @@ class FusedGemmEpilogueKernel : public framework::OpKernel<T> {
cublasComputeType_t compute_type = CUBLAS_COMPUTE_32F; cublasComputeType_t compute_type = CUBLAS_COMPUTE_32F;
if (std::is_same<T, paddle::platform::float16>::value) { if (std::is_same<T, paddle::platform::float16>::value) {
mat_type = CUDA_R_16F; mat_type = CUDA_R_16F;
scale_type = CUDA_R_16F; scale_type = CUDA_R_32F;
} }
if (std::is_same<T, double>::value) { if (std::is_same<T, double>::value) {
mat_type = CUDA_R_64F; mat_type = CUDA_R_64F;
...@@ -106,10 +109,12 @@ class FusedGemmEpilogueKernel : public framework::OpKernel<T> { ...@@ -106,10 +109,12 @@ class FusedGemmEpilogueKernel : public framework::OpKernel<T> {
platform::dynload::cublasLtMatmulDescSetAttribute( platform::dynload::cublasLtMatmulDescSetAttribute(
operation_desc, CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_POINTER, operation_desc, CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_POINTER,
&aux_data, sizeof(aux_data))); &aux_data, sizeof(aux_data)));
// int64_t aux_ld = trans_y ? K : N;
int64_t aux_ld = N;
PADDLE_ENFORCE_GPU_SUCCESS( PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::cublasLtMatmulDescSetAttribute( platform::dynload::cublasLtMatmulDescSetAttribute(
operation_desc, CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_LD, &N, operation_desc, CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_LD, &aux_ld,
sizeof(N))); sizeof(aux_ld)));
} }
cublasLtMatrixLayout_t x_desc = NULL, y_desc = NULL, out_desc = NULL; cublasLtMatrixLayout_t x_desc = NULL, y_desc = NULL, out_desc = NULL;
...@@ -129,7 +134,7 @@ class FusedGemmEpilogueKernel : public framework::OpKernel<T> { ...@@ -129,7 +134,7 @@ class FusedGemmEpilogueKernel : public framework::OpKernel<T> {
&out_desc, mat_type, N, M, N)); &out_desc, mat_type, N, M, N));
cublasLtHandle_t lt_handle = dev_ctx.cublaslt_handle(); cublasLtHandle_t lt_handle = dev_ctx.cublaslt_handle();
size_t workspace_size = 4 * 1024 * 1024; size_t workspace_size = static_cast<size_t>(4) * 1024 * 1024 * 1024;
const cublasLtMatmulAlgo_t* algo = nullptr; const cublasLtMatmulAlgo_t* algo = nullptr;
cudaStream_t stream = dev_ctx.stream(); cudaStream_t stream = dev_ctx.stream();
memory::allocation::AllocationPtr workspace = memory::allocation::AllocationPtr workspace =
...@@ -192,20 +197,27 @@ class FusedGemmEpilogueGradKernel : public framework::OpKernel<T> { ...@@ -192,20 +197,27 @@ class FusedGemmEpilogueGradKernel : public framework::OpKernel<T> {
std::string activation_grad = ctx.Attr<std::string>("activation_grad"); std::string activation_grad = ctx.Attr<std::string>("activation_grad");
auto dout_mat_dims = bool transpose_x = ctx.Attr<bool>("trans_x");
phi::flatten_to_2d(dout->dims(), dout->dims().size() - 1); bool transpose_y = ctx.Attr<bool>("trans_y");
auto x_mat_dims = phi::flatten_to_2d(x->dims(), x->dims().size() - 1);
int64_t M = x_mat_dims[0]; VLOG(10) << "trans_x = " << transpose_x << " , trans_y = " << transpose_y
int64_t K = y->dims()[0]; << " , activation_grad = " << activation_grad;
int64_t N = y->dims()[1];
// activation_grad = "none";
auto x_mat_dims =
phi::flatten_to_2d(x->dims(), transpose_x ? 1 : x->dims().size() - 1);
int64_t M = transpose_x ? x_mat_dims[1] : x_mat_dims[0];
int64_t K = transpose_y ? y->dims()[1] : y->dims()[0];
int64_t N = transpose_y ? y->dims()[0] : y->dims()[1];
cudaDataType_t mat_type = CUDA_R_32F; cudaDataType_t mat_type = CUDA_R_32F;
cudaDataType_t scale_type = CUDA_R_32F; cudaDataType_t scale_type = CUDA_R_32F;
cublasComputeType_t compute_type = CUBLAS_COMPUTE_32F; cublasComputeType_t compute_type = CUBLAS_COMPUTE_32F;
if (std::is_same<T, paddle::platform::float16>::value) { if (std::is_same<T, paddle::platform::float16>::value) {
mat_type = CUDA_R_16F; mat_type = CUDA_R_16F;
scale_type = CUDA_R_16F; scale_type = CUDA_R_32F;
} }
if (std::is_same<T, double>::value) { if (std::is_same<T, double>::value) {
mat_type = CUDA_R_64F; mat_type = CUDA_R_64F;
...@@ -214,7 +226,7 @@ class FusedGemmEpilogueGradKernel : public framework::OpKernel<T> { ...@@ -214,7 +226,7 @@ class FusedGemmEpilogueGradKernel : public framework::OpKernel<T> {
} }
cublasLtHandle_t lt_handle = dev_ctx.cublaslt_handle(); cublasLtHandle_t lt_handle = dev_ctx.cublaslt_handle();
size_t workspace_size = 4 * 1024 * 1024; size_t workspace_size = static_cast<size_t>(4) * 1024 * 1024 * 1024;
const cublasLtMatmulAlgo_t* algo = nullptr; const cublasLtMatmulAlgo_t* algo = nullptr;
cudaStream_t stream = dev_ctx.stream(); cudaStream_t stream = dev_ctx.stream();
...@@ -229,16 +241,54 @@ class FusedGemmEpilogueGradKernel : public framework::OpKernel<T> { ...@@ -229,16 +241,54 @@ class FusedGemmEpilogueGradKernel : public framework::OpKernel<T> {
beta = &beta32; beta = &beta32;
} }
cublasOperation_t trans_dout = CUBLAS_OP_N; cublasLtMatrixLayout_t dout_desc = nullptr, dout_trans_desc = nullptr;
cublasLtMatrixLayout_t dout_desc = NULL; if (dx) {
cublasOperation_t trans_dout = transpose_x ? CUBLAS_OP_T : CUBLAS_OP_N;
cublasOperation_t trans_y =
(transpose_x ^ transpose_y) ? CUBLAS_OP_N : CUBLAS_OP_T;
cublasLtMatrixLayout_t dout_desc_for_dx, y_desc, dx_desc;
if (trans_dout == CUBLAS_OP_T) {
PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::cublasLtMatrixLayoutCreate(&dout_trans_desc,
mat_type, M, N, M));
dout_desc_for_dx = dout_trans_desc;
} else {
PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::cublasLtMatrixLayoutCreate(&dout_desc, mat_type,
N, M, N));
dout_desc_for_dx = dout_desc;
}
if (transpose_y) {
PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::cublasLtMatrixLayoutCreate(&y_desc, mat_type, K,
N, K));
} else {
PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::cublasLtMatrixLayoutCreate(&y_desc, mat_type, N,
K, N));
}
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasLtMatrixLayoutCreate( PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasLtMatrixLayoutCreate(
&dout_desc, mat_type, N, M, N)); &dx_desc, mat_type, K, M, K));
if (dx) {
cublasLtMatmulDesc_t dx_operation_desc = NULL; cublasLtMatmulDesc_t dx_operation_desc = NULL;
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasLtMatmulDescCreate( PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasLtMatmulDescCreate(
&dx_operation_desc, compute_type, scale_type)); &dx_operation_desc, compute_type, scale_type));
cublasOperation_t trans_y = CUBLAS_OP_T;
if (transpose_x) {
// dx = B * dout
PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::cublasLtMatmulDescSetAttribute(
dx_operation_desc, CUBLASLT_MATMUL_DESC_TRANSA, &trans_dout,
sizeof(trans_dout)));
PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::cublasLtMatmulDescSetAttribute(
dx_operation_desc, CUBLASLT_MATMUL_DESC_TRANSB, &trans_y,
sizeof(trans_y)));
} else {
// dx = dout * B
PADDLE_ENFORCE_GPU_SUCCESS( PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::cublasLtMatmulDescSetAttribute( platform::dynload::cublasLtMatmulDescSetAttribute(
dx_operation_desc, CUBLASLT_MATMUL_DESC_TRANSB, &trans_dout, dx_operation_desc, CUBLASLT_MATMUL_DESC_TRANSB, &trans_dout,
...@@ -247,6 +297,8 @@ class FusedGemmEpilogueGradKernel : public framework::OpKernel<T> { ...@@ -247,6 +297,8 @@ class FusedGemmEpilogueGradKernel : public framework::OpKernel<T> {
platform::dynload::cublasLtMatmulDescSetAttribute( platform::dynload::cublasLtMatmulDescSetAttribute(
dx_operation_desc, CUBLASLT_MATMUL_DESC_TRANSA, &trans_y, dx_operation_desc, CUBLASLT_MATMUL_DESC_TRANSA, &trans_y,
sizeof(trans_y))); sizeof(trans_y)));
}
cublasLtEpilogue_t epiloque_func_for_dx = cublasLtEpilogue_t epiloque_func_for_dx =
get_epilogue_type_(activation_grad); get_epilogue_type_(activation_grad);
PADDLE_ENFORCE_GPU_SUCCESS( PADDLE_ENFORCE_GPU_SUCCESS(
...@@ -260,18 +312,13 @@ class FusedGemmEpilogueGradKernel : public framework::OpKernel<T> { ...@@ -260,18 +312,13 @@ class FusedGemmEpilogueGradKernel : public framework::OpKernel<T> {
platform::dynload::cublasLtMatmulDescSetAttribute( platform::dynload::cublasLtMatmulDescSetAttribute(
dx_operation_desc, CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_POINTER, dx_operation_desc, CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_POINTER,
&aux_data, sizeof(aux_data))); &aux_data, sizeof(aux_data)));
int64_t aux_ld = transpose_x ? M : K;
PADDLE_ENFORCE_GPU_SUCCESS( PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::cublasLtMatmulDescSetAttribute( platform::dynload::cublasLtMatmulDescSetAttribute(
dx_operation_desc, CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_LD, &N, dx_operation_desc, CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_LD,
sizeof(N))); &aux_ld, sizeof(aux_ld)));
} }
cublasLtMatrixLayout_t y_desc = NULL, dx_desc = NULL;
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasLtMatrixLayoutCreate(
&y_desc, mat_type, N, K, N));
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasLtMatrixLayoutCreate(
&dx_desc, mat_type, K, M, K));
memory::allocation::AllocationPtr dx_workspace = memory::allocation::AllocationPtr dx_workspace =
memory::Alloc(dev_ctx, workspace_size); memory::Alloc(dev_ctx, workspace_size);
...@@ -284,10 +331,41 @@ class FusedGemmEpilogueGradKernel : public framework::OpKernel<T> { ...@@ -284,10 +331,41 @@ class FusedGemmEpilogueGradKernel : public framework::OpKernel<T> {
} }
if (dy) { if (dy) {
cublasOperation_t trans_dout = transpose_y ? CUBLAS_OP_T : CUBLAS_OP_N;
cublasOperation_t trans_x =
(transpose_x ^ transpose_y) ? CUBLAS_OP_N : CUBLAS_OP_T;
cublasLtMatrixLayout_t dout_desc_for_dx;
if (trans_dout == CUBLAS_OP_T) {
if (dout_trans_desc == nullptr) {
PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::cublasLtMatrixLayoutCreate(&dout_trans_desc,
mat_type, M, N, M));
}
dout_desc_for_dx = dout_trans_desc;
} else {
if (dout_desc == nullptr) {
PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::cublasLtMatrixLayoutCreate(&dout_desc,
mat_type, N, M, N));
}
dout_desc_for_dx = dout_desc;
}
cublasLtMatmulDesc_t dy_operation_desc = NULL; cublasLtMatmulDesc_t dy_operation_desc = NULL;
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasLtMatmulDescCreate( PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasLtMatmulDescCreate(
&dy_operation_desc, compute_type, scale_type)); &dy_operation_desc, compute_type, scale_type));
cublasOperation_t trans_x = CUBLAS_OP_T;
if (transpose_y) {
PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::cublasLtMatmulDescSetAttribute(
dy_operation_desc, CUBLASLT_MATMUL_DESC_TRANSB, &trans_dout,
sizeof(trans_dout)));
PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::cublasLtMatmulDescSetAttribute(
dy_operation_desc, CUBLASLT_MATMUL_DESC_TRANSA, &trans_x,
sizeof(trans_x)));
} else {
PADDLE_ENFORCE_GPU_SUCCESS( PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::cublasLtMatmulDescSetAttribute( platform::dynload::cublasLtMatmulDescSetAttribute(
dy_operation_desc, CUBLASLT_MATMUL_DESC_TRANSA, &trans_dout, dy_operation_desc, CUBLASLT_MATMUL_DESC_TRANSA, &trans_dout,
...@@ -296,9 +374,13 @@ class FusedGemmEpilogueGradKernel : public framework::OpKernel<T> { ...@@ -296,9 +374,13 @@ class FusedGemmEpilogueGradKernel : public framework::OpKernel<T> {
platform::dynload::cublasLtMatmulDescSetAttribute( platform::dynload::cublasLtMatmulDescSetAttribute(
dy_operation_desc, CUBLASLT_MATMUL_DESC_TRANSB, &trans_x, dy_operation_desc, CUBLASLT_MATMUL_DESC_TRANSB, &trans_x,
sizeof(trans_x))); sizeof(trans_x)));
cublasLtEpilogue_t epiloque_func_for_dy = dbias == nullptr }
? CUBLASLT_EPILOGUE_DEFAULT
: CUBLASLT_EPILOGUE_BGRADA; cublasLtEpilogue_t epiloque_func_for_dy =
dbias == nullptr ? CUBLASLT_EPILOGUE_DEFAULT
: (transpose_y ? CUBLASLT_EPILOGUE_BGRADB
: CUBLASLT_EPILOGUE_BGRADA);
PADDLE_ENFORCE_GPU_SUCCESS( PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::cublasLtMatmulDescSetAttribute( platform::dynload::cublasLtMatmulDescSetAttribute(
dy_operation_desc, CUBLASLT_MATMUL_DESC_EPILOGUE, dy_operation_desc, CUBLASLT_MATMUL_DESC_EPILOGUE,
...@@ -314,8 +396,16 @@ class FusedGemmEpilogueGradKernel : public framework::OpKernel<T> { ...@@ -314,8 +396,16 @@ class FusedGemmEpilogueGradKernel : public framework::OpKernel<T> {
} }
cublasLtMatrixLayout_t x_desc = NULL, dy_desc = NULL; cublasLtMatrixLayout_t x_desc = NULL, dy_desc = NULL;
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasLtMatrixLayoutCreate( if (transpose_x) {
&x_desc, mat_type, K, M, K)); PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::cublasLtMatrixLayoutCreate(&x_desc, mat_type, M,
K, M));
} else {
PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::cublasLtMatrixLayoutCreate(&x_desc, mat_type, K,
M, K));
}
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasLtMatrixLayoutCreate( PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasLtMatrixLayoutCreate(
&dy_desc, mat_type, N, K, N)); &dy_desc, mat_type, N, K, N));
......
...@@ -58,6 +58,14 @@ class MultiFCLayer(paddle.nn.Layer): ...@@ -58,6 +58,14 @@ class MultiFCLayer(paddle.nn.Layer):
self.relu3 = Activation() self.relu3 = Activation()
def forward(self, x, matmul_y, ele_y): def forward(self, x, matmul_y, ele_y):
x = self.linear1(x)
x = self.relu1(x)
x = self.linear2(x)
x = self.relu2(x)
x = self.linear3(x)
x = self.relu3(x)
return x
'''
output = self.linear1(x) output = self.linear1(x)
output = self.relu1(output) output = self.relu1(output)
output = self.linear2(output) output = self.linear2(output)
...@@ -71,8 +79,10 @@ class MultiFCLayer(paddle.nn.Layer): ...@@ -71,8 +79,10 @@ class MultiFCLayer(paddle.nn.Layer):
output = self.relu3(output) output = self.relu3(output)
output = paddle.add(output, output1) output = paddle.add(output, output1)
return output return output
'''
'''
@unittest.skipIf(not core.is_compiled_with_cuda(), @unittest.skipIf(not core.is_compiled_with_cuda(),
"core is not compiled with CUDA") "core is not compiled with CUDA")
class TestFuseGemmEpilogueFWDBase(unittest.TestCase): class TestFuseGemmEpilogueFWDBase(unittest.TestCase):
...@@ -218,6 +228,7 @@ class TestFuseGemmEpilogueGeluFWDFP16(TestFuseGemmEpilogueGeluFWDFP32): ...@@ -218,6 +228,7 @@ class TestFuseGemmEpilogueGeluFWDFP16(TestFuseGemmEpilogueGeluFWDFP32):
self.data_arr = self.data_arr.astype("float16") self.data_arr = self.data_arr.astype("float16")
self.matmul_y_arr = self.matmul_y_arr.astype("float16") self.matmul_y_arr = self.matmul_y_arr.astype("float16")
self.ele_y_arr = self.ele_y_arr.astype("float16") self.ele_y_arr = self.ele_y_arr.astype("float16")
'''
@unittest.skipIf(not core.is_compiled_with_cuda(), @unittest.skipIf(not core.is_compiled_with_cuda(),
...@@ -327,6 +338,7 @@ class TestFuseGemmEpilogueBWDBase(unittest.TestCase): ...@@ -327,6 +338,7 @@ class TestFuseGemmEpilogueBWDBase(unittest.TestCase):
return paddle.nn.ReLU, "relu", "relu_grad" return paddle.nn.ReLU, "relu", "relu_grad"
'''
@unittest.skipIf(not core.is_compiled_with_cuda(), @unittest.skipIf(not core.is_compiled_with_cuda(),
"core is not compiled with CUDA") "core is not compiled with CUDA")
class TestFuseGemmEpilogueReLUBWDFP32(TestFuseGemmEpilogueBWDBase): class TestFuseGemmEpilogueReLUBWDFP32(TestFuseGemmEpilogueBWDBase):
...@@ -339,8 +351,8 @@ class TestFuseGemmEpilogueReLUBWDFP32(TestFuseGemmEpilogueBWDBase): ...@@ -339,8 +351,8 @@ class TestFuseGemmEpilogueReLUBWDFP32(TestFuseGemmEpilogueBWDBase):
def test_output(self): def test_output(self):
self._test_output() self._test_output()
'''
'''
@unittest.skipIf(not core.is_compiled_with_cuda(), @unittest.skipIf(not core.is_compiled_with_cuda(),
"core is not compiled with CUDA") "core is not compiled with CUDA")
class TestFuseGemmEpilogueReLUBWDFP16(TestFuseGemmEpilogueReLUBWDFP32): class TestFuseGemmEpilogueReLUBWDFP16(TestFuseGemmEpilogueReLUBWDFP32):
...@@ -355,6 +367,7 @@ class TestFuseGemmEpilogueReLUBWDFP16(TestFuseGemmEpilogueReLUBWDFP32): ...@@ -355,6 +367,7 @@ class TestFuseGemmEpilogueReLUBWDFP16(TestFuseGemmEpilogueReLUBWDFP32):
self.data_arr = self.data_arr.astype("float16") self.data_arr = self.data_arr.astype("float16")
self.matmul_y_arr = self.matmul_y_arr.astype("float16") self.matmul_y_arr = self.matmul_y_arr.astype("float16")
self.ele_y_arr = self.ele_y_arr.astype("float16") self.ele_y_arr = self.ele_y_arr.astype("float16")
'''
@unittest.skipIf(not core.is_compiled_with_cuda(), @unittest.skipIf(not core.is_compiled_with_cuda(),
...@@ -371,6 +384,7 @@ class TestFuseGemmEpilogueGeLUBWDFP32(TestFuseGemmEpilogueBWDBase): ...@@ -371,6 +384,7 @@ class TestFuseGemmEpilogueGeLUBWDFP32(TestFuseGemmEpilogueBWDBase):
self._test_output() self._test_output()
'''
@unittest.skipIf(not core.is_compiled_with_cuda(), @unittest.skipIf(not core.is_compiled_with_cuda(),
"core is not compiled with CUDA") "core is not compiled with CUDA")
class TestFuseGemmEpilogueGeLUBWDFP16(TestFuseGemmEpilogueGeLUBWDFP32): class TestFuseGemmEpilogueGeLUBWDFP16(TestFuseGemmEpilogueGeLUBWDFP32):
...@@ -385,7 +399,7 @@ class TestFuseGemmEpilogueGeLUBWDFP16(TestFuseGemmEpilogueGeLUBWDFP32): ...@@ -385,7 +399,7 @@ class TestFuseGemmEpilogueGeLUBWDFP16(TestFuseGemmEpilogueGeLUBWDFP32):
self.data_arr = self.data_arr.astype("float16") self.data_arr = self.data_arr.astype("float16")
self.matmul_y_arr = self.matmul_y_arr.astype("float16") self.matmul_y_arr = self.matmul_y_arr.astype("float16")
self.ele_y_arr = self.ele_y_arr.astype("float16") self.ele_y_arr = self.ele_y_arr.astype("float16")
'''
if __name__ == "__main__": if __name__ == "__main__":
np.random.seed(0) np.random.seed(0)
......
...@@ -235,5 +235,6 @@ class TestFuseGemmEpilogueGradOpDXYFP64(TestFuseGemmEpilogueGradOpDXYFP16): ...@@ -235,5 +235,6 @@ class TestFuseGemmEpilogueGradOpDXYFP64(TestFuseGemmEpilogueGradOpDXYFP16):
if __name__ == "__main__": if __name__ == "__main__":
paddle.enable_static()
np.random.seed(0) np.random.seed(0)
unittest.main() unittest.main()
...@@ -446,5 +446,6 @@ class TestFuseGemmEpilogueOpNoneMMFP64(TestFuseGemmEpilogueOpNoneMMFP16): ...@@ -446,5 +446,6 @@ class TestFuseGemmEpilogueOpNoneMMFP64(TestFuseGemmEpilogueOpNoneMMFP16):
if __name__ == "__main__": if __name__ == "__main__":
paddle.enable_static()
np.random.seed(0) np.random.seed(0)
unittest.main() unittest.main()
...@@ -1470,7 +1470,7 @@ def cosine_similarity(x1, x2, axis=1, eps=1e-8): ...@@ -1470,7 +1470,7 @@ def cosine_similarity(x1, x2, axis=1, eps=1e-8):
return cos_sim return cos_sim
def linear(x, weight, bias=None, name=None): def linear(x, weight, bias=None, name=None, weight_transpose=False):
r""" r"""
Fully-connected linear transformation operator. For each input :math:`X` , Fully-connected linear transformation operator. For each input :math:`X` ,
...@@ -1523,7 +1523,7 @@ def linear(x, weight, bias=None, name=None): ...@@ -1523,7 +1523,7 @@ def linear(x, weight, bias=None, name=None):
""" """
if in_dynamic_mode(): if in_dynamic_mode():
pre_bias = _C_ops.matmul_v2(x, weight, 'trans_x', False, 'trans_y', pre_bias = _C_ops.matmul_v2(x, weight, 'trans_x', False, 'trans_y',
False) weight_transpose)
if bias is None: if bias is None:
return pre_bias return pre_bias
...@@ -1538,7 +1538,7 @@ def linear(x, weight, bias=None, name=None): ...@@ -1538,7 +1538,7 @@ def linear(x, weight, bias=None, name=None):
check_dtype(dtype, 'dtype', ['float16', 'float32', 'float64'], 'linear') check_dtype(dtype, 'dtype', ['float16', 'float32', 'float64'], 'linear')
inputs = {'X': [x], 'Y': [weight]} inputs = {'X': [x], 'Y': [weight]}
attrs = {'trans_x': False, 'trans_y': False} attrs = {'trans_x': False, 'trans_y': weight_transpose}
tmp = helper.create_variable_for_type_inference(dtype) tmp = helper.create_variable_for_type_inference(dtype)
helper.append_op( helper.append_op(
type='matmul_v2', inputs=inputs, outputs={'Out': tmp}, attrs=attrs) type='matmul_v2', inputs=inputs, outputs={'Out': tmp}, attrs=attrs)
......
...@@ -150,13 +150,15 @@ class Linear(Layer): ...@@ -150,13 +150,15 @@ class Linear(Layer):
out_features, out_features,
weight_attr=None, weight_attr=None,
bias_attr=None, bias_attr=None,
name=None): name=None,
weight_transpose=False):
super(Linear, self).__init__() super(Linear, self).__init__()
self._dtype = self._helper.get_default_dtype() self._dtype = self._helper.get_default_dtype()
self._weight_attr = weight_attr self._weight_attr = weight_attr
self._bias_attr = bias_attr self._bias_attr = bias_attr
self.weight = self.create_parameter( self.weight = self.create_parameter(
shape=[in_features, out_features], shape=[out_features, in_features]
if weight_transpose else [in_features, out_features],
attr=self._weight_attr, attr=self._weight_attr,
dtype=self._dtype, dtype=self._dtype,
is_bias=False) is_bias=False)
...@@ -165,11 +167,16 @@ class Linear(Layer): ...@@ -165,11 +167,16 @@ class Linear(Layer):
attr=self._bias_attr, attr=self._bias_attr,
dtype=self._dtype, dtype=self._dtype,
is_bias=True) is_bias=True)
self.weight_transpose = weight_transpose
self.name = name self.name = name
def forward(self, input): def forward(self, input):
out = F.linear( out = F.linear(
x=input, weight=self.weight, bias=self.bias, name=self.name) x=input,
weight=self.weight,
bias=self.bias,
name=self.name,
weight_transpose=self.weight_transpose)
return out return out
def extra_repr(self): def extra_repr(self):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册