diff --git a/paddle/fluid/framework/distributed_strategy.proto b/paddle/fluid/framework/distributed_strategy.proto index fff78dd872c99f2b9ec9a5998ce43f2f6ad8e40b..94753f8dd38e09a4038986fac2f6103896fe59c8 100644 --- a/paddle/fluid/framework/distributed_strategy.proto +++ b/paddle/fluid/framework/distributed_strategy.proto @@ -120,6 +120,7 @@ message BuildStrategy { optional bool fix_op_run_order = 13 [ default = false ]; optional bool allow_cuda_graph_capture = 14 [ default = false ]; optional int32 reduce_strategy = 15 [ default = 0 ]; + optional bool fuse_gemm_epilogue = 16 [ default = false ]; } message ExecutionStrategy { diff --git a/paddle/fluid/framework/ir/fuse_gemm_epilogue_pass.cc b/paddle/fluid/framework/ir/fuse_gemm_epilogue_pass.cc index f48224cbdc24fe9706a3c4eae029c6dc35381ad2..b72a63d37853c04ead4547ae4c384c5282c8abcd 100644 --- a/paddle/fluid/framework/ir/fuse_gemm_epilogue_pass.cc +++ b/paddle/fluid/framework/ir/fuse_gemm_epilogue_pass.cc @@ -22,6 +22,12 @@ namespace paddle { namespace framework { namespace ir { +static void GetTransposeAttrsFromOp(const OpDesc &op, bool *trans_x, + bool *trans_y) { + *trans_x = BOOST_GET_CONST(bool, op.GetAttr("trans_x")); + *trans_y = BOOST_GET_CONST(bool, op.GetAttr("trans_y")); +} + void FuseGemmEpiloguePass::ApplyImpl(ir::Graph *graph) const { EpiloguePassActivationCache cache; @@ -75,6 +81,9 @@ ir::Graph *FuseGemmEpiloguePass::FuseLinearFwd(ir::Graph *graph, if (!IsGemmFromLinear_(matmul_x_shape, matmul_w_shape, matmul_op_desc)) return; + bool trans_x, trans_y; + GetTransposeAttrsFromOp(*matmul_op_desc, &trans_x, &trans_y); + OpDesc fused_gemm_epilogue_op_desc(matmul_op->Op()->Block()); std::string activation = "none"; fused_gemm_epilogue_op_desc.SetType("fused_gemm_epilogue"); @@ -85,6 +94,8 @@ ir::Graph *FuseGemmEpiloguePass::FuseLinearFwd(ir::Graph *graph, fused_gemm_epilogue_op_desc.SetAttr("activation", activation); fused_gemm_epilogue_op_desc.SetAttr("op_role", matmul_op_desc->GetAttr("op_role")); + fused_gemm_epilogue_op_desc.SetAttr("trans_x", trans_x); + fused_gemm_epilogue_op_desc.SetAttr("trans_y", trans_y); auto gemm_epilogue_node = g->CreateOpNode(&fused_gemm_epilogue_op_desc); IR_NODE_LINK_TO(subgraph.at(x), gemm_epilogue_node); @@ -154,6 +165,9 @@ ir::Graph *FuseGemmEpiloguePass::FuseLinearActFwd( auto activation = act_op->Op()->Type(); + bool trans_x, trans_y; + GetTransposeAttrsFromOp(*matmul_op_desc, &trans_x, &trans_y); + OpDesc fused_gemm_epilogue_op_desc(matmul_op->Op()->Block()); fused_gemm_epilogue_op_desc.SetType("fused_gemm_epilogue"); fused_gemm_epilogue_op_desc.SetInput("X", {subgraph.at(x)->Name()}); @@ -163,6 +177,8 @@ ir::Graph *FuseGemmEpiloguePass::FuseLinearActFwd( fused_gemm_epilogue_op_desc.SetAttr("activation", activation); fused_gemm_epilogue_op_desc.SetAttr("op_role", matmul_op_desc->GetAttr("op_role")); + fused_gemm_epilogue_op_desc.SetAttr("trans_x", trans_x); + fused_gemm_epilogue_op_desc.SetAttr("trans_y", trans_y); auto gemm_epilogue_node = g->CreateOpNode(&fused_gemm_epilogue_op_desc); @@ -274,6 +290,9 @@ ir::Graph *FuseGemmEpiloguePass::FuseLinearBwd(ir::Graph *graph, matmul_grad_op_desc)) return; + bool trans_x, trans_y; + GetTransposeAttrsFromOp(*matmul_grad_op_desc, &trans_x, &trans_y); + OpDesc fused_gemm_epilogue_grad_op_desc(ele_add_grad_op->Op()->Block()); std::string activation_grad = "none"; fused_gemm_epilogue_grad_op_desc.SetType("fused_gemm_epilogue_grad"); @@ -292,6 +311,8 @@ ir::Graph *FuseGemmEpiloguePass::FuseLinearBwd(ir::Graph *graph, activation_grad); 
fused_gemm_epilogue_grad_op_desc.SetAttr( "op_role", matmul_grad_op_desc->GetAttr("op_role")); + fused_gemm_epilogue_grad_op_desc.SetAttr("trans_x", trans_x); + fused_gemm_epilogue_grad_op_desc.SetAttr("trans_y", trans_y); auto gemm_epilogue_grad_node = g->CreateOpNode(&fused_gemm_epilogue_grad_op_desc); @@ -394,6 +415,8 @@ ir::Graph *FuseGemmEpiloguePass::FuseLinearActBwd( auto activation_grad = act_grad_op->Op()->Type(); + bool trans_x, trans_y; + GetTransposeAttrsFromOp(*matmul_grad_op_desc, &trans_x, &trans_y); OpDesc fused_gemm_epilogue_grad_op_desc(ele_add_grad_op->Op()->Block()); fused_gemm_epilogue_grad_op_desc.SetType("fused_gemm_epilogue_grad"); fused_gemm_epilogue_grad_op_desc.SetInput("DOut", @@ -410,6 +433,8 @@ ir::Graph *FuseGemmEpiloguePass::FuseLinearActBwd( activation_grad); fused_gemm_epilogue_grad_op_desc.SetAttr( "op_role", matmul_grad_op_desc->GetAttr("op_role")); + fused_gemm_epilogue_grad_op_desc.SetAttr("trans_x", trans_x); + fused_gemm_epilogue_grad_op_desc.SetAttr("trans_y", trans_y); auto gemm_epilogue_grad_node = g->CreateOpNode(&fused_gemm_epilogue_grad_op_desc); @@ -456,10 +481,6 @@ bool FuseGemmEpiloguePass::IsGemmFromLinear_( if (tmp_vec.size() > 0) return false; } } - if (BOOST_GET_CONST(bool, matmul_v2_op->GetAttr("trans_x")) || - BOOST_GET_CONST(bool, matmul_v2_op->GetAttr("trans_y"))) - return false; - return true; } diff --git a/paddle/fluid/operators/fused/fused_gemm_epilogue_op.cc b/paddle/fluid/operators/fused/fused_gemm_epilogue_op.cc index 4c4e3661e6d6edc5ea95b77cd283cc99afcca8ed..7cb6777e5a79ac55d15bf369cd00957904541b01 100644 --- a/paddle/fluid/operators/fused/fused_gemm_epilogue_op.cc +++ b/paddle/fluid/operators/fused/fused_gemm_epilogue_op.cc @@ -13,6 +13,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ +#include "paddle/fluid/operators/fused/fused_gemm_epilogue_op.h" #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_version_registry.h" @@ -208,6 +209,9 @@ class FusedGemmEpilogueGradOp : public framework::OperatorWithKernel { auto x_dims = ctx->GetInputDim("X"); auto y_dims = ctx->GetInputDim("Y"); + auto trans_x = ctx->Attrs().Get("trans_x"); + auto trans_y = ctx->Attrs().Get("trans_y"); + PADDLE_ENFORCE_GE( dout_dims.size(), 2, platform::errors::InvalidArgument( @@ -242,14 +246,14 @@ class FusedGemmEpilogueGradOp : public framework::OperatorWithKernel { auto x_mat_dims = phi::flatten_to_2d(x_dims, x_dims.size() - 1); PADDLE_ENFORCE_EQ( - dout_mat_dims[1], y_dims[1], + dout_mat_dims[1], trans_y ? y_dims[0] : y_dims[1], platform::errors::InvalidArgument( "The last dimension of DOut should be equal with Y's last" "dimension. But received DOut[-1] = [%d], Y[1] = [%d].", dout_mat_dims[1], y_dims[1])); PADDLE_ENFORCE_EQ( - dout_mat_dims[0], x_mat_dims[0], + dout_mat_dims[0], trans_x ? x_mat_dims[1] : x_mat_dims[0], platform::errors::InvalidArgument( "The first dimension of DOut should be equal with X's first" "dimension. But received DOut[0] = [%d], Y[0] = [%d].", @@ -288,7 +292,7 @@ class FusedGemmEpilogueGradOp : public framework::OperatorWithKernel { if (ctx->HasOutput("DBias")) { std::vector dbias_dims; - dbias_dims.push_back(y_dims[1]); + dbias_dims.push_back(trans_y ? 
y_dims[0] : y_dims[1]); ctx->SetOutputDim("DBias", phi::make_ddim(dbias_dims)); } } @@ -323,6 +327,20 @@ class FusedGemmEpilogueGradOpMaker : public framework::OpProtoAndCheckerMaker { AddOutput("DBias", "The output grad tensor to bias of Out = (Act(X) * Y) + bias.") .AsDispensable(); + AddAttr( + "trans_x", + R"DOC((bool, default false), Whether to transpose input tensor X + or not. The input tensor X coulbe be more than two dimension. When + set trans_x=true, it would fully reverse X. For instant: X with shpae + [d0, d1, d2, d3] -> [d3, d2, d1, d0].)DOC") + .SetDefault(false); + AddAttr( + "trans_y", + R"DOC((bool, default false), Whether to transpose input tensor Y + or not. The input tensor Y should be two dimension. When + set trans_y=true, it would transpose Y. For instant: Y with shpae + [d0, d1] -> [d1, d0].)DOC") + .SetDefault(false); AddAttr( "activation_grad", @@ -343,11 +361,38 @@ X with shape [d0, d1, d2, d3] -> X_2D with shape [d0*d1*d2, d3] } }; +template +class FusedGemmEpilogueOpGradMaker : public framework::SingleGradOpMaker { + public: + using framework::SingleGradOpMaker::SingleGradOpMaker; + + protected: + void Apply(GradOpPtr op) const override { + const auto& act_type = this->template Attr("activation"); + PADDLE_ENFORCE_EQ(act_type, "none", phi::errors::InvalidArgument( + "The activation should be none.")); + + op->SetType(this->ForwardOpType() + "_grad"); + op->SetInput("X", this->Input("X")); + op->SetInput("Y", this->Input("Y")); + op->SetInput("DOut", this->OutputGrad("Out")); + + op->SetOutput("DX", this->InputGrad("X")); + op->SetOutput("DY", this->InputGrad("Y")); + op->SetOutput("DBias", this->InputGrad("Bias")); + + op->SetAttrMap(this->Attrs()); + } +}; + } // namespace operators } // namespace paddle namespace ops = paddle::operators; -REGISTER_OPERATOR(fused_gemm_epilogue, ops::FusedGemmEpilogueOp, - ops::FusedGemmEpilogueOpMaker) +REGISTER_OPERATOR( + fused_gemm_epilogue, ops::FusedGemmEpilogueOp, + ops::FusedGemmEpilogueOpMaker, + ops::FusedGemmEpilogueOpGradMaker, + ops::FusedGemmEpilogueOpGradMaker); REGISTER_OPERATOR(fused_gemm_epilogue_grad, ops::FusedGemmEpilogueGradOp, - ops::FusedGemmEpilogueGradOpMaker) + ops::FusedGemmEpilogueGradOpMaker); diff --git a/paddle/fluid/operators/fused/fused_gemm_epilogue_op.cu b/paddle/fluid/operators/fused/fused_gemm_epilogue_op.cu index 9bf3d1a485efc71a19960525cb427ffb823eeefa..407cd2b974def8e8566672ea80bac9fd5d7491ba 100644 --- a/paddle/fluid/operators/fused/fused_gemm_epilogue_op.cu +++ b/paddle/fluid/operators/fused/fused_gemm_epilogue_op.cu @@ -15,6 +15,7 @@ limitations under the License. */ #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_version_registry.h" +#include "paddle/fluid/framework/scope_guard.h" #include "paddle/fluid/operators/fused/fused_gemm_epilogue_op.h" #include "paddle/fluid/platform/dynload/cublasLt.h" #include "paddle/fluid/platform/float16.h" @@ -41,6 +42,8 @@ class FusedGemmEpilogueKernel : public framework::OpKernel { bool trans_y = ctx.Attr("trans_y"); std::string activation = ctx.Attr("activation"); + VLOG(10) << "trans_x = " << trans_x << " , trans_y = " << trans_y + << " , activation = " << activation; bool enable_auxiliary = reserve_space == nullptr ? false : true; out->mutable_data(ctx.GetPlace()); @@ -48,6 +51,7 @@ class FusedGemmEpilogueKernel : public framework::OpKernel { auto x_mat_dims = phi::flatten_to_2d(x->dims(), trans_x ? 1 : x->dims().size() - 1); + // (M * K) * (K * N) int64_t M = trans_x ? 
x_mat_dims[1] : x_mat_dims[0]; int64_t K = trans_y ? y->dims()[1] : y->dims()[0]; int64_t N = trans_y ? y->dims()[0] : y->dims()[1]; @@ -106,10 +110,11 @@ class FusedGemmEpilogueKernel : public framework::OpKernel { platform::dynload::cublasLtMatmulDescSetAttribute( operation_desc, CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_POINTER, &aux_data, sizeof(aux_data))); + int64_t aux_ld = N; PADDLE_ENFORCE_GPU_SUCCESS( platform::dynload::cublasLtMatmulDescSetAttribute( - operation_desc, CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_LD, &N, - sizeof(N))); + operation_desc, CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_LD, &aux_ld, + sizeof(aux_ld))); } cublasLtMatrixLayout_t x_desc = NULL, y_desc = NULL, out_desc = NULL; @@ -129,8 +134,7 @@ class FusedGemmEpilogueKernel : public framework::OpKernel { &out_desc, mat_type, N, M, N)); cublasLtHandle_t lt_handle = dev_ctx.cublaslt_handle(); - size_t workspace_size = 4 * 1024 * 1024; - + size_t workspace_size = static_cast(4) * 1024 * 1024 * 1024; cudaStream_t stream = dev_ctx.stream(); memory::allocation::AllocationPtr workspace = memory::Alloc(dev_ctx, workspace_size); @@ -149,13 +153,13 @@ class FusedGemmEpilogueKernel : public framework::OpKernel { const auto* y_data = y->data(); const auto* x_data = x->data(); - cublasLtMatmulAlgo_t algo = GemmEpilogueAlgoCache::Instance().GetGemmAlgo( + auto algo = GemmEpilogueAlgoCache::Instance().GetGemmAlgo( lt_handle, operation_desc, y_desc, x_desc, out_desc, alpha, beta, y_data, x_data, out_data, stream, workspace->ptr(), workspace_size); PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasLtMatmul( lt_handle, operation_desc, alpha, y_data, y_desc, x_data, x_desc, beta, - out_data, out_desc, out_data, out_desc, &algo, workspace->ptr(), + out_data, out_desc, out_data, out_desc, algo, workspace->ptr(), workspace_size, stream)); PADDLE_ENFORCE_GPU_SUCCESS( @@ -191,12 +195,94 @@ class FusedGemmEpilogueKernel : public framework::OpKernel { } }; +enum FusedGEMMGradInType { kDX = 0, kDY = 1, kDZ = 2 }; + +template +struct FusedGEMMGradTrait; + +template <> +struct FusedGEMMGradTrait { + static constexpr auto kXGradA = FusedGEMMGradInType::kDZ; + static constexpr auto kXGradB = FusedGEMMGradInType::kDY; + static constexpr auto kXGradATrans = false; + static constexpr auto kXGradBTrans = true; + + static constexpr auto kYGradA = FusedGEMMGradInType::kDX; + static constexpr auto kYGradB = FusedGEMMGradInType::kDZ; + static constexpr auto kYGradATrans = true; + static constexpr auto kYGradBTrans = false; +}; + +template <> +struct FusedGEMMGradTrait { + static constexpr auto kXGradA = FusedGEMMGradInType::kDY; + static constexpr auto kXGradB = FusedGEMMGradInType::kDZ; + static constexpr auto kXGradATrans = false; + static constexpr auto kXGradBTrans = true; + + static constexpr auto kYGradA = FusedGEMMGradInType::kDX; + static constexpr auto kYGradB = FusedGEMMGradInType::kDZ; + static constexpr auto kYGradATrans = false; + static constexpr auto kYGradBTrans = false; +}; + +template <> +struct FusedGEMMGradTrait { + static constexpr auto kXGradA = FusedGEMMGradInType::kDZ; + static constexpr auto kXGradB = FusedGEMMGradInType::kDY; + static constexpr auto kXGradATrans = false; + static constexpr auto kXGradBTrans = false; + + static constexpr auto kYGradA = FusedGEMMGradInType::kDZ; + static constexpr auto kYGradB = FusedGEMMGradInType::kDX; + static constexpr auto kYGradATrans = true; + static constexpr auto kYGradBTrans = false; +}; + +template <> +struct FusedGEMMGradTrait { + static constexpr auto kXGradA = FusedGEMMGradInType::kDY; + static 
constexpr auto kXGradB = FusedGEMMGradInType::kDZ; + static constexpr auto kXGradATrans = true; + static constexpr auto kXGradBTrans = true; + + static constexpr auto kYGradA = FusedGEMMGradInType::kDZ; + static constexpr auto kYGradB = FusedGEMMGradInType::kDX; + static constexpr auto kYGradATrans = true; + static constexpr auto kYGradBTrans = true; +}; + +static constexpr auto BoolToCuBlasEnum(bool transpose) { + return transpose ? CUBLAS_OP_T : CUBLAS_OP_N; +} + template class FusedGemmEpilogueGradKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { - auto& dev_ctx = ctx.template device_context(); + bool transpose_x = ctx.Attr("trans_x"); + bool transpose_y = ctx.Attr("trans_y"); + if (transpose_x) { + if (transpose_y) { + ComputeImpl(ctx); + } else { + ComputeImpl(ctx); + } + } else { + if (transpose_y) { + ComputeImpl(ctx); + } else { + ComputeImpl(ctx); + } + } + } + + private: + template + static void ComputeImpl(const framework::ExecutionContext& ctx) { + using Trait = FusedGEMMGradTrait; + auto& dev_ctx = ctx.template device_context(); const Tensor* dout = ctx.Input("DOut"); const Tensor* x = ctx.Input("X"); const Tensor* y = ctx.Input("Y"); @@ -208,13 +294,18 @@ class FusedGemmEpilogueGradKernel : public framework::OpKernel { std::string activation_grad = ctx.Attr("activation_grad"); - auto dout_mat_dims = - phi::flatten_to_2d(dout->dims(), dout->dims().size() - 1); - auto x_mat_dims = phi::flatten_to_2d(x->dims(), x->dims().size() - 1); + VLOG(10) << "trans_x = " << TransX << " , trans_y = " << TransY + << " , activation_grad = " << activation_grad; + + auto x_mat_dims = + phi::flatten_to_2d(x->dims(), TransX ? 1 : x->dims().size() - 1); + + // (M * K) * (K * N) + int64_t M = TransX ? x_mat_dims[1] : x_mat_dims[0]; + int64_t K = TransY ? y->dims()[1] : y->dims()[0]; + int64_t N = TransY ? 
y->dims()[0] : y->dims()[1]; - int64_t M = x_mat_dims[0]; - int64_t K = y->dims()[0]; - int64_t N = y->dims()[1]; + VLOG(10) << "M = " << M << " , K = " << K << " , N = " << N; cudaDataType_t mat_type = CUDA_R_32F; cudaDataType_t scale_type = CUDA_R_32F; @@ -229,7 +320,8 @@ class FusedGemmEpilogueGradKernel : public framework::OpKernel { } cublasLtHandle_t lt_handle = dev_ctx.cublaslt_handle(); - size_t workspace_size = 4 * 1024 * 1024; + size_t workspace_size = static_cast(4) * 1024 * 1024 * 1024; + const cublasLtMatmulAlgo_t* algo = nullptr; cudaStream_t stream = dev_ctx.stream(); double alpha64 = 1.0, beta64 = 0.0; @@ -243,24 +335,81 @@ class FusedGemmEpilogueGradKernel : public framework::OpKernel { beta = &beta32; } - cublasOperation_t trans_dout = CUBLAS_OP_N; - cublasLtMatrixLayout_t dout_desc = NULL; - PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasLtMatrixLayoutCreate( - &dout_desc, mat_type, N, M, N)); + cublasLtMatrixLayout_t dout_desc = nullptr, dout_trans_desc = nullptr; + cublasLtMatrixLayout_t x_desc = nullptr, x_trans_desc = nullptr; + cublasLtMatrixLayout_t y_desc = nullptr, y_trans_desc = nullptr; + cublasLtMatrixLayout_t dx_desc = nullptr, dy_desc = nullptr; + cublasLtMatmulDesc_t dx_operation_desc = nullptr, + dy_operation_desc = nullptr; + + DEFINE_PADDLE_SCOPE_GUARD([&] { + auto descs = {dout_desc, dout_trans_desc, x_desc, x_trans_desc, + y_desc, y_trans_desc, dx_desc, dy_desc}; + for (auto desc : descs) { + if (desc) { + PADDLE_ENFORCE_GPU_SUCCESS( + platform::dynload::cublasLtMatrixLayoutDestroy(desc)); + } + } + if (dx_operation_desc) { + PADDLE_ENFORCE_GPU_SUCCESS( + platform::dynload::cublasLtMatmulDescDestroy(dx_operation_desc)); + } + + if (dy_operation_desc) { + PADDLE_ENFORCE_GPU_SUCCESS( + platform::dynload::cublasLtMatmulDescDestroy(dy_operation_desc)); + } + }); + + auto x_row = TransX ? K : M; + auto x_col = TransX ? M : K; + auto y_row = TransY ? N : K; + auto y_col = TransY ? K : N; + auto z_row = TransX ? N : M; + auto z_col = TransX ? M : N; + + // dx = func(dout, y) if (dx) { - cublasLtMatmulDesc_t dx_operation_desc = NULL; + constexpr auto kXGradAIsDZ = (Trait::kXGradA == FusedGEMMGradInType::kDZ); + cublasLtMatrixLayout_t *dx_dout_desc, *dx_y_desc; + + if (TransX) { + dx_dout_desc = &dout_trans_desc; + PADDLE_ENFORCE_GPU_SUCCESS( + platform::dynload::cublasLtMatrixLayoutCreate( + dx_dout_desc, mat_type, z_row, z_col, z_row)); + } else { + dx_dout_desc = &dout_desc; + PADDLE_ENFORCE_GPU_SUCCESS( + platform::dynload::cublasLtMatrixLayoutCreate( + dx_dout_desc, mat_type, z_col, z_row, z_col)); + } + + dx_y_desc = &y_trans_desc; + PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasLtMatrixLayoutCreate( + dx_y_desc, mat_type, y_col, y_row, y_col)); + + auto& a_desc = kXGradAIsDZ ? (*dx_dout_desc) : (*dx_y_desc); + auto& b_desc = kXGradAIsDZ ? 
(*dx_y_desc) : (*dx_dout_desc); + auto a_trans = BoolToCuBlasEnum(Trait::kXGradATrans); + auto b_trans = BoolToCuBlasEnum(Trait::kXGradBTrans); + + PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasLtMatrixLayoutCreate( + &dx_desc, mat_type, x_col, x_row, x_col)); + PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasLtMatmulDescCreate( &dx_operation_desc, compute_type, scale_type)); - cublasOperation_t trans_y = CUBLAS_OP_T; PADDLE_ENFORCE_GPU_SUCCESS( platform::dynload::cublasLtMatmulDescSetAttribute( - dx_operation_desc, CUBLASLT_MATMUL_DESC_TRANSB, &trans_dout, - sizeof(trans_dout))); + dx_operation_desc, CUBLASLT_MATMUL_DESC_TRANSB, &a_trans, + sizeof(a_trans))); PADDLE_ENFORCE_GPU_SUCCESS( platform::dynload::cublasLtMatmulDescSetAttribute( - dx_operation_desc, CUBLASLT_MATMUL_DESC_TRANSA, &trans_y, - sizeof(trans_y))); + dx_operation_desc, CUBLASLT_MATMUL_DESC_TRANSA, &b_trans, + sizeof(b_trans))); + cublasLtEpilogue_t epiloque_func_for_dx = get_epilogue_type_(activation_grad); PADDLE_ENFORCE_GPU_SUCCESS( @@ -274,105 +423,116 @@ class FusedGemmEpilogueGradKernel : public framework::OpKernel { platform::dynload::cublasLtMatmulDescSetAttribute( dx_operation_desc, CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_POINTER, &aux_data, sizeof(aux_data))); + int64_t aux_ld = TransX ? M : K; PADDLE_ENFORCE_GPU_SUCCESS( platform::dynload::cublasLtMatmulDescSetAttribute( - dx_operation_desc, CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_LD, &K, - sizeof(K))); + dx_operation_desc, CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_LD, + &aux_ld, sizeof(aux_ld))); } - cublasLtMatrixLayout_t y_desc = NULL, dx_desc = NULL; - PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasLtMatrixLayoutCreate( - &y_desc, mat_type, N, K, N)); - PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasLtMatrixLayoutCreate( - &dx_desc, mat_type, K, M, K)); - - memory::allocation::AllocationPtr dx_workspace = - memory::Alloc(dev_ctx, workspace_size); + auto dx_workspace = memory::Alloc(dev_ctx, workspace_size); - dx->mutable_data(ctx.GetPlace()); - auto* dx_data = dx->data(); + auto* dx_data = dx->mutable_data(ctx.GetPlace()); const auto* y_data = y->data(); const auto* dout_data = dout->data(); + const auto* a_data = kXGradAIsDZ ? dout_data : y_data; + const auto* b_data = kXGradAIsDZ ? 
y_data : dout_data; - cublasLtMatmulAlgo_t algo = GemmEpilogueAlgoCache::Instance().GetGemmAlgo( - lt_handle, dx_operation_desc, y_desc, dout_desc, dx_desc, alpha, beta, - y_data, dout_data, dx_data, stream, dx_workspace->ptr(), - workspace_size); + auto algo = GemmEpilogueAlgoCache::Instance().GetGemmAlgo( + lt_handle, dx_operation_desc, b_desc, a_desc, dx_desc, alpha, beta, + b_data, a_data, dx_data, stream, dx_workspace->ptr(), workspace_size); PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasLtMatmul( - lt_handle, dx_operation_desc, alpha, y->data(), y_desc, - dout->data(), dout_desc, beta, dx_data, dx_desc, dx_data, dx_desc, - &algo, dx_workspace->ptr(), workspace_size, stream)); - PADDLE_ENFORCE_GPU_SUCCESS( - platform::dynload::cublasLtMatmulDescDestroy(dx_operation_desc)); - PADDLE_ENFORCE_GPU_SUCCESS( - platform::dynload::cublasLtMatrixLayoutDestroy(y_desc)); - PADDLE_ENFORCE_GPU_SUCCESS( - platform::dynload::cublasLtMatrixLayoutDestroy(dx_desc)); + lt_handle, dx_operation_desc, alpha, b_data, b_desc, a_data, a_desc, + beta, dx_data, dx_desc, dx_data, dx_desc, algo, dx_workspace->ptr(), + workspace_size, stream)); } + // dy = func(dout, x) if (dy) { - cublasLtMatmulDesc_t dy_operation_desc = NULL; + constexpr auto kYGradAIsDZ = (Trait::kYGradA == FusedGEMMGradInType::kDZ); + + cublasLtMatrixLayout_t *dy_dout_desc = nullptr, *dy_x_desc = nullptr; + if (TransX) { + dy_dout_desc = &dout_trans_desc; + if (dout_trans_desc == nullptr) { + PADDLE_ENFORCE_GPU_SUCCESS( + platform::dynload::cublasLtMatrixLayoutCreate( + dy_dout_desc, mat_type, z_row, z_col, z_row)); + } + } else { + dy_dout_desc = &dout_desc; + if (dout_desc == nullptr) { + PADDLE_ENFORCE_GPU_SUCCESS( + platform::dynload::cublasLtMatrixLayoutCreate( + dy_dout_desc, mat_type, z_col, z_row, z_col)); + } + } + + dy_x_desc = &x_trans_desc; + PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasLtMatrixLayoutCreate( + dy_x_desc, mat_type, x_col, x_row, x_col)); + + auto& a_desc = kYGradAIsDZ ? (*dy_dout_desc) : (*dy_x_desc); + auto& b_desc = kYGradAIsDZ ? (*dy_x_desc) : (*dy_dout_desc); + auto a_trans = BoolToCuBlasEnum(Trait::kYGradATrans); + auto b_trans = BoolToCuBlasEnum(Trait::kYGradBTrans); + + PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasLtMatrixLayoutCreate( + &dy_desc, mat_type, y_col, y_row, y_col)); + PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasLtMatmulDescCreate( &dy_operation_desc, compute_type, scale_type)); - cublasOperation_t trans_x = CUBLAS_OP_T; + PADDLE_ENFORCE_GPU_SUCCESS( platform::dynload::cublasLtMatmulDescSetAttribute( - dy_operation_desc, CUBLASLT_MATMUL_DESC_TRANSA, &trans_dout, - sizeof(trans_dout))); + dy_operation_desc, CUBLASLT_MATMUL_DESC_TRANSB, &a_trans, + sizeof(a_trans))); PADDLE_ENFORCE_GPU_SUCCESS( platform::dynload::cublasLtMatmulDescSetAttribute( - dy_operation_desc, CUBLASLT_MATMUL_DESC_TRANSB, &trans_x, - sizeof(trans_x))); - cublasLtEpilogue_t epiloque_func_for_dy = dbias == nullptr - ? 
CUBLASLT_EPILOGUE_DEFAULT - : CUBLASLT_EPILOGUE_BGRADA; + dy_operation_desc, CUBLASLT_MATMUL_DESC_TRANSA, &b_trans, + sizeof(b_trans))); + + cublasLtEpilogue_t epiloque_func_for_dy; + if (dbias == nullptr) { + epiloque_func_for_dy = CUBLASLT_EPILOGUE_DEFAULT; + } else { + if (TransY) { + epiloque_func_for_dy = CUBLASLT_EPILOGUE_BGRADB; + } else { + epiloque_func_for_dy = CUBLASLT_EPILOGUE_BGRADA; + } + } + PADDLE_ENFORCE_GPU_SUCCESS( platform::dynload::cublasLtMatmulDescSetAttribute( dy_operation_desc, CUBLASLT_MATMUL_DESC_EPILOGUE, &epiloque_func_for_dy, sizeof(epiloque_func_for_dy))); if (dbias) { - dbias->mutable_data(ctx.GetPlace()); - auto* dbias_data = dbias->data(); + auto* dbias_data = dbias->mutable_data(ctx.GetPlace()); PADDLE_ENFORCE_GPU_SUCCESS( platform::dynload::cublasLtMatmulDescSetAttribute( dy_operation_desc, CUBLASLT_MATMUL_DESC_BIAS_POINTER, &dbias_data, sizeof(dbias_data))); } - cublasLtMatrixLayout_t x_desc = NULL, dy_desc = NULL; - PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasLtMatrixLayoutCreate( - &x_desc, mat_type, K, M, K)); - PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasLtMatrixLayoutCreate( - &dy_desc, mat_type, N, K, N)); - - memory::allocation::AllocationPtr dy_workspace = - memory::Alloc(dev_ctx, workspace_size); - - dy->mutable_data(ctx.GetPlace()); - auto* dy_data = dy->data(); + auto dy_workspace = memory::Alloc(dev_ctx, workspace_size); + auto* dy_data = dy->mutable_data(ctx.GetPlace()); const auto* dout_data = dout->data(); const auto* x_data = x->data(); + const auto* a_data = kYGradAIsDZ ? dout_data : x_data; + const auto* b_data = kYGradAIsDZ ? x_data : dout_data; - cublasLtMatmulAlgo_t algo = GemmEpilogueAlgoCache::Instance().GetGemmAlgo( - lt_handle, dy_operation_desc, dout_desc, x_desc, dy_desc, alpha, beta, - dout_data, x_data, dy_data, stream, dy_workspace->ptr(), - workspace_size); + auto algo = GemmEpilogueAlgoCache::Instance().GetGemmAlgo( + lt_handle, dy_operation_desc, b_desc, a_desc, dy_desc, alpha, beta, + b_data, a_data, dy_data, stream, dy_workspace->ptr(), workspace_size); PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasLtMatmul( - lt_handle, dy_operation_desc, alpha, dout_data, dout_desc, x_data, - x_desc, beta, dy_data, dy_desc, dy_data, dy_desc, &algo, - dy_workspace->ptr(), workspace_size, stream)); - PADDLE_ENFORCE_GPU_SUCCESS( - platform::dynload::cublasLtMatmulDescDestroy(dy_operation_desc)); - PADDLE_ENFORCE_GPU_SUCCESS( - platform::dynload::cublasLtMatrixLayoutDestroy(x_desc)); - PADDLE_ENFORCE_GPU_SUCCESS( - platform::dynload::cublasLtMatrixLayoutDestroy(dy_desc)); + lt_handle, dy_operation_desc, alpha, b_data, b_desc, a_data, a_desc, + beta, dy_data, dy_desc, dy_data, dy_desc, algo, dy_workspace->ptr(), + workspace_size, stream)); } - PADDLE_ENFORCE_GPU_SUCCESS( - platform::dynload::cublasLtMatrixLayoutDestroy(dout_desc)); } private: diff --git a/paddle/fluid/operators/fused/fused_gemm_epilogue_op.h b/paddle/fluid/operators/fused/fused_gemm_epilogue_op.h index c90a6966fe0a841dd3eb692aaafcdd03535b16a0..8ff41b2c9616bbda80bf9be7fd3e8d9556560c86 100644 --- a/paddle/fluid/operators/fused/fused_gemm_epilogue_op.h +++ b/paddle/fluid/operators/fused/fused_gemm_epilogue_op.h @@ -21,7 +21,9 @@ limitations under the License. 
*/ #include #include "gflags/gflags.h" #include "paddle/fluid/platform/dynload/cublasLt.h" +#include "paddle/fluid/platform/enforce.h" #include "paddle/fluid/platform/float16.h" +#include "paddle/utils/optional.h" DECLARE_int64(cublaslt_exhaustive_search_times); @@ -39,12 +41,14 @@ class GemmEpilogueAlgoCache { GemmEpilogueAlgoCache(GemmEpilogueAlgoCache const &) = delete; void operator=(GemmEpilogueAlgoCache const &) = delete; - cublasLtMatmulAlgo_t GetGemmAlgo( + cublasLtMatmulAlgo_t *GetGemmAlgo( cublasLtHandle_t lt_handle, cublasLtMatmulDesc_t op_desc, cublasLtMatrixLayout_t a_desc, cublasLtMatrixLayout_t b_desc, cublasLtMatrixLayout_t c_desc, const void *alpha, const void *beta, const void *a, const void *b, void *c, cudaStream_t stream, void *workspace, size_t workspace_size) { + if (search_times_ <= 0) return nullptr; + int64_t seed = 0; std::hash hash_fn; @@ -54,132 +58,108 @@ class GemmEpilogueAlgoCache { HashMatrixLayoutDesc_(c_desc, &seed, hash_fn); cublasLtMatmulAlgo_t ret; - auto it = map_.end(); - bool have_found = false; { std::lock_guard lock(cache_mutex_); - it = map_.find(seed); - + auto it = map_.find(seed); if (it != map_.end()) { - ret = it->second; - have_found = true; + return &(it->second); } } - if (!have_found) { - cublasLtMatmulPreference_t preference; - PADDLE_ENFORCE_GPU_SUCCESS( - platform::dynload::cublasLtMatmulPreferenceCreate(&preference)); - PADDLE_ENFORCE_GPU_SUCCESS( - platform::dynload::cublasLtMatmulPreferenceSetAttribute( - preference, CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, - &workspace_size, sizeof(workspace_size))); - - int returned_results = 0; - cublasLtMatmulHeuristicResult_t heuristic_results[requested_algo_count_] = - {0}; - PADDLE_ENFORCE_GPU_SUCCESS( - platform::dynload::cublasLtMatmulAlgoGetHeuristic( - lt_handle, op_desc, a_desc, b_desc, c_desc, c_desc, preference, - requested_algo_count_, heuristic_results, &returned_results)); - - PADDLE_ENFORCE_GT( - returned_results, 0, - platform::errors::Unavailable("No GEMM epilogue algorithm support!")); - - PADDLE_ENFORCE_GPU_SUCCESS( - platform::dynload::cublasLtMatmulPreferenceDestroy(preference)); - - if (search_times_ > 0) { - int best_algo_idx = -1; - float best_algo_time = 0; - - // Run 100 times for warmup - int warmup_algo_idx = 0; - for (int t = 0; t < 100; t++) { - cublasStatus_t status = platform::dynload::cublasLtMatmul( - lt_handle, op_desc, alpha, a, a_desc, b, b_desc, beta, c, c_desc, - c, c_desc, &heuristic_results[warmup_algo_idx].algo, workspace, - workspace_size, stream); - if (status != CUBLAS_STATUS_SUCCESS) { - t = -1; - warmup_algo_idx += 1; - if (warmup_algo_idx == requested_algo_count_) { - PADDLE_THROW(platform::errors::Unavailable( - "No GEMM epilogue algorithm support!")); - } - } - } + cublasLtMatmulPreference_t preference; + PADDLE_ENFORCE_GPU_SUCCESS( + platform::dynload::cublasLtMatmulPreferenceCreate(&preference)); + PADDLE_ENFORCE_GPU_SUCCESS( + platform::dynload::cublasLtMatmulPreferenceSetAttribute( + preference, CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, + &workspace_size, sizeof(workspace_size))); - cudaEvent_t start_event, stop_event; - PADDLE_ENFORCE_GPU_SUCCESS(cudaEventCreate(&start_event)); - PADDLE_ENFORCE_GPU_SUCCESS(cudaEventCreate(&stop_event)); - - for (int algo_idx = 0; algo_idx < returned_results; ++algo_idx) { - float curr_time = 0; - for (int check_idx = 0; check_idx < search_times_; check_idx++) { - float time = 0; - PADDLE_ENFORCE_GPU_SUCCESS(cudaEventRecord(start_event, stream)); - - cublasStatus_t status = platform::dynload::cublasLtMatmul( 
- lt_handle, op_desc, alpha, a, a_desc, b, b_desc, beta, c, - c_desc, c, c_desc, &heuristic_results[algo_idx].algo, workspace, - workspace_size, stream); - - PADDLE_ENFORCE_GPU_SUCCESS(cudaEventRecord(stop_event, stream)); - PADDLE_ENFORCE_GPU_SUCCESS(cudaEventSynchronize(stop_event)); - PADDLE_ENFORCE_GPU_SUCCESS( - cudaEventElapsedTime(&time, start_event, stop_event)); - curr_time += time; - if (status != CUBLAS_STATUS_SUCCESS) { - curr_time = 3.40282e+038; // Max Value of float - break; - } - } - - curr_time = curr_time / search_times_; - if (curr_time < best_algo_time || algo_idx == 0) { - best_algo_idx = algo_idx; - best_algo_time = curr_time; - } - } + int returned_results = 0; + std::vector heuristic_results( + requested_algo_count_); + PADDLE_ENFORCE_GPU_SUCCESS( + platform::dynload::cublasLtMatmulAlgoGetHeuristic( + lt_handle, op_desc, a_desc, b_desc, c_desc, c_desc, preference, + requested_algo_count_, heuristic_results.data(), + &returned_results)); - PADDLE_ENFORCE_GPU_SUCCESS(cudaEventDestroy(start_event)); - PADDLE_ENFORCE_GPU_SUCCESS(cudaEventDestroy(stop_event)); + PADDLE_ENFORCE_GT( + returned_results, 0, + platform::errors::Unavailable("No GEMM epilogue algorithm support!")); - if (best_algo_idx == -1) { + PADDLE_ENFORCE_GPU_SUCCESS( + platform::dynload::cublasLtMatmulPreferenceDestroy(preference)); + + int best_algo_idx = -1; + float best_algo_time = 0; + + // Run 100 times for warmup + int warmup_algo_idx = 0; + for (int t = 0; t < 100; t++) { + cublasStatus_t status = platform::dynload::cublasLtMatmul( + lt_handle, op_desc, alpha, a, a_desc, b, b_desc, beta, c, c_desc, c, + c_desc, &heuristic_results[warmup_algo_idx].algo, workspace, + workspace_size, stream); + if (status != CUBLAS_STATUS_SUCCESS) { + t = -1; + warmup_algo_idx += 1; + if (warmup_algo_idx == requested_algo_count_) { PADDLE_THROW(platform::errors::Unavailable( "No GEMM epilogue algorithm support!")); } + } + } - ret = heuristic_results[best_algo_idx].algo; - } else { - int decided_algo_idx = -1; - for (int algo_idx = 0; algo_idx < returned_results; ++algo_idx) { - cublasStatus_t status = platform::dynload::cublasLtMatmul( - lt_handle, op_desc, alpha, a, a_desc, b, b_desc, beta, c, c_desc, - c, c_desc, &heuristic_results[algo_idx].algo, workspace, - workspace_size, stream); - if (status == CUBLAS_STATUS_SUCCESS) { - decided_algo_idx = algo_idx; - break; - } - } - if (decided_algo_idx == -1) { - PADDLE_THROW(platform::errors::Unavailable( - "No GEMM epilogue algorithm support!")); + cudaEvent_t start_event, stop_event; + PADDLE_ENFORCE_GPU_SUCCESS(cudaEventCreate(&start_event)); + PADDLE_ENFORCE_GPU_SUCCESS(cudaEventCreate(&stop_event)); + + for (int algo_idx = 0; algo_idx < returned_results; ++algo_idx) { + float curr_time = 0; + for (int check_idx = 0; check_idx < search_times_; check_idx++) { + float time = 0; + PADDLE_ENFORCE_GPU_SUCCESS(cudaEventRecord(start_event, stream)); + + cublasStatus_t status = platform::dynload::cublasLtMatmul( + lt_handle, op_desc, alpha, a, a_desc, b, b_desc, beta, c, c_desc, c, + c_desc, &heuristic_results[algo_idx].algo, workspace, + workspace_size, stream); + + PADDLE_ENFORCE_GPU_SUCCESS(cudaEventRecord(stop_event, stream)); + PADDLE_ENFORCE_GPU_SUCCESS(cudaEventSynchronize(stop_event)); + PADDLE_ENFORCE_GPU_SUCCESS( + cudaEventElapsedTime(&time, start_event, stop_event)); + curr_time += time; + if (status != CUBLAS_STATUS_SUCCESS) { + curr_time = 3.40282e+038; // Max Value of float + break; } - ret = heuristic_results[decided_algo_idx].algo; } - std::lock_guard 
lock(cache_mutex_); - map_[seed] = ret; + curr_time = curr_time / search_times_; + if (curr_time < best_algo_time || algo_idx == 0) { + best_algo_idx = algo_idx; + best_algo_time = curr_time; + } } - VLOG(4) << "Search time:" << search_times_ << ", Is hash-key (" << seed - << ") found in GemmEpilogueAlgoCache? " << have_found; + PADDLE_ENFORCE_GPU_SUCCESS(cudaEventDestroy(start_event)); + PADDLE_ENFORCE_GPU_SUCCESS(cudaEventDestroy(stop_event)); + + if (best_algo_idx == -1) { + PADDLE_THROW( + platform::errors::Unavailable("No GEMM epilogue algorithm support!")); + } + + ret = heuristic_results[best_algo_idx].algo; + + VLOG(4) << "Search time:" << search_times_ << ", hash-key (" << seed + << ") not found in GemmEpilogueAlgoCache"; - return ret; + std::lock_guard lock(cache_mutex_); + auto &algo_in_map = map_[seed]; + algo_in_map = ret; + return &algo_in_map; } private: diff --git a/python/paddle/fluid/ir.py b/python/paddle/fluid/ir.py index 55297ed516ffb4f2e64abb44030b642785f03cbd..2756eac990ed328134c87808cb2a08137e829d86 100644 --- a/python/paddle/fluid/ir.py +++ b/python/paddle/fluid/ir.py @@ -101,6 +101,9 @@ def apply_build_strategy(main_program, startup_program, build_strategy, if build_strategy.enable_auto_fusion and use_cuda: apply_pass("fusion_group_pass") build_strategy.enable_auto_fusion = False + if build_strategy.fuse_gemm_epilogue: + apply_pass("fuse_gemm_epilogue_pass") + build_strategy.fuse_gemm_epilogue = False if build_strategy.fuse_elewise_add_act_ops: apply_pass("fuse_elewise_add_act_pass") build_strategy.fuse_elewise_add_act_ops = False diff --git a/python/paddle/fluid/tests/unittests/op_test.py b/python/paddle/fluid/tests/unittests/op_test.py index f7a3dfa1102b295096662f71d965e6aa48406a5a..fe1dbf3b92743d364e11e09d57142f2c5300289f 100644 --- a/python/paddle/fluid/tests/unittests/op_test.py +++ b/python/paddle/fluid/tests/unittests/op_test.py @@ -263,6 +263,18 @@ def skip_check_grad_ci(reason=None): return wrapper +def skip_check_inplace_ci(reason=None): + if not isinstance(reason, str): + raise AssertionError( + "The reason for skipping check_inplace is required.") + + def wrapper(cls): + cls.no_need_check_inplace = True + return cls + + return wrapper + + def copy_bits_from_float_to_uint16(f): return struct.unpack('> 16 @@ -1288,6 +1300,9 @@ class OpTest(unittest.TestCase): Returns: None """ + if getattr(self, "no_need_check_inplace", False): + return + has_infer_inplace = fluid.core.has_infer_inplace(self.op_type) has_grad_op_maker = fluid.core.has_grad_op_maker(self.op_type) diff --git a/python/paddle/fluid/tests/unittests/test_fused_gemm_epilogue_grad_op.py b/python/paddle/fluid/tests/unittests/test_fused_gemm_epilogue_grad_op.py index 2ea1bf2e9cb8105280a4f2635279518d125a4312..106ce5b4ef055311a5ba511c0c0b90612e410fbe 100644 --- a/python/paddle/fluid/tests/unittests/test_fused_gemm_epilogue_grad_op.py +++ b/python/paddle/fluid/tests/unittests/test_fused_gemm_epilogue_grad_op.py @@ -235,5 +235,6 @@ class TestFuseGemmEpilogueGradOpDXYFP64(TestFuseGemmEpilogueGradOpDXYFP16): if __name__ == "__main__": + paddle.enable_static() np.random.seed(0) unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_fused_gemm_epilogue_op.py b/python/paddle/fluid/tests/unittests/test_fused_gemm_epilogue_op.py index f826898f9e5dd601b54eaeb1c54216414a70246b..4256945a1e8d572f35cdabb1579bc8ffbfc38644 100644 --- a/python/paddle/fluid/tests/unittests/test_fused_gemm_epilogue_op.py +++ b/python/paddle/fluid/tests/unittests/test_fused_gemm_epilogue_op.py @@ -19,7 +19,7 @@ import 
unittest import numpy as np import paddle import paddle.fluid.core as core -from op_test import OpTest, skip_check_grad_ci +from op_test import OpTest, skip_check_grad_ci, skip_check_inplace_ci def gelu(x): @@ -43,10 +43,15 @@ def get_output(X, Y, bias, act): return out +@skip_check_inplace_ci(reason="no inplace op") +class TestFuseGemmBase(OpTest): + pass + + @skip_check_grad_ci(reason="no grap op") @unittest.skipIf(not core.is_compiled_with_cuda(), "core is not compiled with CUDA") -class TestFuseGemmEpilogueOpReluMMFP16(OpTest): +class TestFuseGemmEpilogueOpReluMMFP16(TestFuseGemmBase): def setUp(self): self.op_type = "fused_gemm_epilogue" self.place = core.CUDAPlace(0) @@ -95,7 +100,7 @@ class TestFuseGemmEpilogueOpReluMMFP64(TestFuseGemmEpilogueOpReluMMFP16): @skip_check_grad_ci(reason="no grap op") @unittest.skipIf(not core.is_compiled_with_cuda(), "core is not compiled with CUDA") -class TestFuseGemmEpilogueOpReluMTMFP16(OpTest): +class TestFuseGemmEpilogueOpReluMTMFP16(TestFuseGemmBase): def setUp(self): self.op_type = "fused_gemm_epilogue" self.place = core.CUDAPlace(0) @@ -144,7 +149,7 @@ class TestFuseGemmEpilogueOpReluMTMFP64(TestFuseGemmEpilogueOpReluMTMFP16): @skip_check_grad_ci(reason="no grap op") @unittest.skipIf(not core.is_compiled_with_cuda(), "core is not compiled with CUDA") -class TestFuseGemmEpilogueOpReluMMTFP16(OpTest): +class TestFuseGemmEpilogueOpReluMMTFP16(TestFuseGemmBase): def setUp(self): self.op_type = "fused_gemm_epilogue" self.place = core.CUDAPlace(0) @@ -193,7 +198,7 @@ class TestFuseGemmEpilogueOpReluMMTFP64(TestFuseGemmEpilogueOpReluMMTFP16): @skip_check_grad_ci(reason="no grap op") @unittest.skipIf(not core.is_compiled_with_cuda(), "core is not compiled with CUDA") -class TestFuseGemmEpilogueOpReluMTMTFP16(OpTest): +class TestFuseGemmEpilogueOpReluMTMTFP16(TestFuseGemmBase): def setUp(self): self.op_type = "fused_gemm_epilogue" self.place = core.CUDAPlace(0) @@ -242,7 +247,7 @@ class TestFuseGemmEpilogueOpReluMTMTFP64(TestFuseGemmEpilogueOpReluMTMTFP16): @skip_check_grad_ci(reason="no grap op") @unittest.skipIf(not core.is_compiled_with_cuda(), "core is not compiled with CUDA") -class TestFuseGemmEpilogueOpReluMMFP16MultiDimX(OpTest): +class TestFuseGemmEpilogueOpReluMMFP16MultiDimX(TestFuseGemmBase): def setUp(self): self.op_type = "fused_gemm_epilogue" self.place = core.CUDAPlace(0) @@ -294,7 +299,7 @@ class TestFuseGemmEpilogueOpReluMMFP64MultiDimX( @skip_check_grad_ci(reason="no grap op") @unittest.skipIf(not core.is_compiled_with_cuda(), "core is not compiled with CUDA") -class TestFuseGemmEpilogueOpReluMTMFP16MultiDimX(OpTest): +class TestFuseGemmEpilogueOpReluMTMFP16MultiDimX(TestFuseGemmBase): def setUp(self): self.op_type = "fused_gemm_epilogue" self.place = core.CUDAPlace(0) @@ -346,7 +351,7 @@ class TestFuseGemmEpilogueOpReluMTMFP64MultiDimX( @skip_check_grad_ci(reason="no grap op") @unittest.skipIf(not core.is_compiled_with_cuda(), "core is not compiled with CUDA") -class TestFuseGemmEpilogueOpGeluMMFP16(OpTest): +class TestFuseGemmEpilogueOpGeluMMFP16(TestFuseGemmBase): def setUp(self): self.op_type = "fused_gemm_epilogue" self.place = core.CUDAPlace(0) @@ -397,7 +402,7 @@ class TestFuseGemmEpilogueOpGeluMMFP64(TestFuseGemmEpilogueOpGeluMMFP16): @skip_check_grad_ci(reason="no grap op") @unittest.skipIf(not core.is_compiled_with_cuda(), "core is not compiled with CUDA") -class TestFuseGemmEpilogueOpNoneMMFP16(OpTest): +class TestFuseGemmEpilogueOpNoneMMFP16(TestFuseGemmBase): def setUp(self): self.op_type = "fused_gemm_epilogue" 
self.place = core.CUDAPlace(0) @@ -446,5 +451,6 @@ class TestFuseGemmEpilogueOpNoneMMFP64(TestFuseGemmEpilogueOpNoneMMFP16): if __name__ == "__main__": + paddle.enable_static() np.random.seed(0) unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_fused_matmul_bias.py b/python/paddle/fluid/tests/unittests/test_fused_matmul_bias.py new file mode 100644 index 0000000000000000000000000000000000000000..98548c9996588938a96d431b11cd9443d3f968ae --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_fused_matmul_bias.py @@ -0,0 +1,162 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import paddle +import paddle.fluid.core as core +import unittest +import numpy as np +from paddle.incubate.nn.functional import fused_matmul_bias, fused_linear +from paddle.incubate.nn import FusedLinear + + +def is_fused_matmul_bias_supported(): + if paddle.is_compiled_with_cuda() and not paddle.is_compiled_with_rocm(): + return hasattr(core.ops, 'fused_gemm_epilogue') + else: + return False + + +def matmul(x, y, bias, trans_x, trans_y): + x = np.array(x) + if trans_x: + x = np.ascontiguousarray(np.transpose(x)) + if trans_y: + y = np.ascontiguousarray(np.transpose(y)) + z = np.matmul(x, y) + if bias is None: + return z + else: + return z + bias + + +def matmul_grad(x, y, bias, dz, trans_x, trans_y): + if trans_x: + if trans_y: + dx = matmul(y, dz, None, True, True) + dy = matmul(dz, x, None, True, True) + else: + dx = matmul(y, dz, None, False, True) + dy = matmul(x, dz, None, False, False) + else: + if trans_y: + dx = matmul(dz, y, None, False, False) + dy = matmul(dz, x, None, True, False) + else: + dx = matmul(dz, y, None, False, True) + dy = matmul(x, dz, None, True, False) + if bias is None: + dbias = None + else: + dbias = np.sum(dz, axis=0, keepdims=False) + return dx, dy, dbias + + +@unittest.skipIf( + not is_fused_matmul_bias_supported(), + "fused_gemm_epilogue is only supported when CUDA version >= 11.6") +class TestFusedMatmulBias(unittest.TestCase): + def setUp(self): + paddle.set_device('gpu') + + def rand_data(self, shape, dtype): + return np.random.randint(low=-20, high=20, size=shape).astype(dtype) + + def rand_test_base(self, m, n, k, trans_x, trans_y, need_bias, dtype, seed): + np.random.seed(seed) + x_shape = [k, m] if trans_x else [m, k] + y_shape = [n, k] if trans_y else [k, n] + bias_shape = [n] + + x_np = self.rand_data(x_shape, dtype) + x = paddle.to_tensor(x_np) + x.stop_gradient = False + + y_np = self.rand_data(y_shape, dtype) + y = paddle.to_tensor(y_np) + y.stop_gradient = False + + if need_bias: + bias_np = self.rand_data(bias_shape, dtype) + bias = paddle.to_tensor(bias_np) + bias.stop_gradient = False + else: + bias_np = None + bias = None + + z = fused_matmul_bias(x, y, bias, trans_x, trans_y) + z_np = matmul(x_np, y_np, bias_np, trans_x, trans_y) + self.assertTrue(np.array_equal(z.numpy(), z_np)) + + z_grad_np = self.rand_data(z_np.shape, dtype) + paddle.autograd.backward(z, 
grad_tensors=[paddle.to_tensor(z_grad_np)]) + + x_grad_np, y_grad_np, bias_grad_np = matmul_grad( + x_np, y_np, bias_np, z_grad_np, trans_x, trans_y) + self.assertTrue(np.array_equal(x.grad.numpy(), x_grad_np)) + self.assertEqual(y_grad_np.shape, y_np.shape) + self.assertTrue(np.array_equal(y.grad.numpy(), y_grad_np)) + + if need_bias: + self.assertTrue(np.array_equal(bias.grad.numpy(), bias_grad_np)) + else: + self.assertTrue(bias_grad_np is None) + + def rand_test(self, m, n, k, dtype): + seed = int(np.random.randint(low=0, high=1000, size=[1])) + for trans_x in [False, True]: + for trans_y in [False, True]: + for need_bias in [False, True]: + self.rand_test_base(m, n, k, trans_x, trans_y, need_bias, + dtype, seed) + + def test_fp32(self): + self.rand_test(30, 40, 50, np.float32) + + def test_fp16(self): + self.rand_test(4, 5, 7, np.float16) + + +@unittest.skipIf( + not is_fused_matmul_bias_supported(), + "fused_gemm_epilogue is only supported when CUDA version >= 11.6") +class TestFusedLinear(unittest.TestCase): + def check_fused_linear(self, transpose): + x = paddle.randn([30, 40]) + linear = FusedLinear(40, 50, transpose_weight=transpose) + y1 = linear(x) + y2 = fused_linear(x, linear.weight, linear.bias, transpose) + self.assertTrue(np.array_equal(y1.numpy(), y2.numpy())) + + def test_non_transpose(self): + self.check_fused_linear(False) + + def test_transpose(self): + self.check_fused_linear(True) + + +@unittest.skipIf( + not is_fused_matmul_bias_supported(), + "fused_gemm_epilogue is only supported when CUDA version >= 11.6") +class TestStaticGraph(unittest.TestCase): + def test_static_graph(self): + paddle.enable_static() + x = paddle.static.data(name='x', dtype='float32', shape=[-1, 100]) + linear = FusedLinear(100, 300) + y = linear(x) + self.assertEqual(list(y.shape), [-1, 300]) + paddle.disable_static() + + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/incubate/nn/__init__.py b/python/paddle/incubate/nn/__init__.py index 3c806aa646ebe3157ab06819c12829806f502aa0..cf15ee7d8ffaa321b2700c38b2dbea8682ad0a3f 100644 --- a/python/paddle/incubate/nn/__init__.py +++ b/python/paddle/incubate/nn/__init__.py @@ -16,6 +16,7 @@ from .layer.fused_transformer import FusedMultiHeadAttention # noqa: F401 from .layer.fused_transformer import FusedFeedForward # noqa: F401 from .layer.fused_transformer import FusedTransformerEncoderLayer # noqa: F401 from .layer.fused_transformer import FusedMultiTransformer # noqa: F401 +from .layer.fused_linear import FusedLinear # noqa: F401 from .layer.fused_transformer import FusedBiasDropoutResidualLayerNorm # noqa: F401 __all__ = [ #noqa @@ -23,5 +24,6 @@ __all__ = [ #noqa 'FusedFeedForward', 'FusedTransformerEncoderLayer', 'FusedMultiTransformer', + 'FusedLinear', 'FusedBiasDropoutResidualLayerNorm', ] diff --git a/python/paddle/incubate/nn/functional/__init__.py b/python/paddle/incubate/nn/functional/__init__.py index 02e44548ce5d87a4be505dc6a2981405ee3cc938..e9894990455abf65972457ce67cdfeb164711b2c 100644 --- a/python/paddle/incubate/nn/functional/__init__.py +++ b/python/paddle/incubate/nn/functional/__init__.py @@ -15,11 +15,14 @@ from .fused_transformer import fused_multi_head_attention from .fused_transformer import fused_feedforward from .fused_transformer import fused_multi_transformer +from .fused_matmul_bias import fused_matmul_bias, fused_linear from .fused_transformer import fused_bias_dropout_residual_layer_norm __all__ = [ 'fused_multi_head_attention', 'fused_feedforward', 'fused_multi_transformer', + 
'fused_matmul_bias', + 'fused_linear', 'fused_bias_dropout_residual_layer_norm', ] diff --git a/python/paddle/incubate/nn/functional/fused_matmul_bias.py b/python/paddle/incubate/nn/functional/fused_matmul_bias.py new file mode 100644 index 0000000000000000000000000000000000000000..bcc2e62144589164779b7f927742e72ae0ee770d --- /dev/null +++ b/python/paddle/incubate/nn/functional/fused_matmul_bias.py @@ -0,0 +1,105 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from paddle.fluid.layer_helper import LayerHelper +from paddle.fluid.framework import _non_static_mode +from paddle.tensor.linalg import matmul +from paddle import _C_ops + + +def fused_matmul_bias(x, + y, + bias=None, + transpose_x=False, + transpose_y=False, + name=None): + """ + Applies matrix multiplication of two tensors and then bias addition if provided. + This method requires CUDA version >= 11.6. + + Args: + x (Tensor): the first input Tensor to be multiplied. + y (Tensor): the second input Tensor to be multiplied. Its rank must be 2. + bias (Tensor|None): the input bias Tensor. If it is None, no bias addition would + be performed. Otherwise, the bias is added to the matrix multiplication result. + transpose_x (bool): Whether to transpose :math:`x` before multiplication. + transpose_y (bool): Whether to transpose :math:`y` before multiplication. + name(str|None): For detailed information, please refer to + :ref:`api_guide_Name` . Usually name is no need to set and None by default. + + Returns: + Tensor: the output Tensor. + + Examples: + .. code-block:: python + + # required: gpu + import paddle + from paddle.incubate.nn.functional import fused_matmul_bias + + x = paddle.randn([3, 4]) + y = paddle.randn([4, 5]) + bias = paddle.randn([5]) + out = fused_matmul_bias(x, y, bias) + print(out.shape) # [3, 5] + """ + if bias is None: + return matmul(x, y, transpose_x, transpose_y, name) + if _non_static_mode(): + return _C_ops.fused_gemm_epilogue(x, y, bias, 'trans_x', transpose_x, + 'trans_y', transpose_y) + + helper = LayerHelper('fused_matmul_bias', **locals()) + out = helper.create_variable_for_type_inference(dtype=x.dtype) + helper.append_op( + type='fused_gemm_epilogue', + inputs={'X': x, + 'Y': y, + 'Bias': bias}, + outputs={'Out': out}, + attrs={'trans_x': transpose_x, + 'trans_y': transpose_y}) + return out + + +def fused_linear(x, weight, bias=None, transpose_weight=False, name=None): + """ + Fully-connected linear transformation operator. This method requires CUDA version >= 11.6. + + Args: + x (Tensor): the input Tensor to be multiplied. + weight (Tensor): the weight Tensor to be multiplied. Its rank must be 2. + bias (Tensor|None): the input bias Tensor. If it is None, no bias addition would + be performed. Otherwise, the bias is added to the matrix multiplication result. + transpose_weight (bool): Whether to transpose :math:`weight` before multiplication. + name(str|None): For detailed information, please refer to + :ref:`api_guide_Name` . 
Usually name is no need to set and None by default. + + Returns: + Tensor: the output Tensor. + + Examples: + .. code-block:: python + + # required: gpu + import paddle + from paddle.incubate.nn.functional import fused_linear + + x = paddle.randn([3, 4]) + weight = paddle.randn([4, 5]) + bias = paddle.randn([5]) + out = fused_linear(x, weight, bias) + print(out.shape) # [3, 5] + """ + return fused_matmul_bias(x, weight, bias, False, transpose_weight, name) diff --git a/python/paddle/incubate/nn/layer/fused_linear.py b/python/paddle/incubate/nn/layer/fused_linear.py new file mode 100644 index 0000000000000000000000000000000000000000..f7c872c3993cf747d302394129186d9f2cc779f5 --- /dev/null +++ b/python/paddle/incubate/nn/layer/fused_linear.py @@ -0,0 +1,91 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from paddle.nn import Layer +from paddle.incubate.nn import functional as F + + +class FusedLinear(Layer): + """ + Linear layer takes only one multi-dimensional tensor as input with the + shape :math:`[batch\_size, *, in\_features]` , where :math:`*` means any + number of additional dimensions. It multiplies input tensor with the weight + (a 2-D tensor of shape :math:`[in\_features, out\_features]` ) and produces + an output tensor of shape :math:`[batch\_size, *, out\_features]` . + If :math:`bias\_attr` is not False, the bias (a 1-D tensor of + shape :math:`[out\_features]` ) will be created and added to the output. + + Parameters: + in_features (int): The number of input units. + out_features (int): The number of output units. + weight_attr (ParamAttr, optional): The attribute for the learnable + weight of this layer. The default value is None and the weight will be + initialized to zero. For detailed information, please refer to + paddle.ParamAttr. + transpose_weight (bool): Whether to transpose the `weight` Tensor before + multiplication. + bias_attr (ParamAttr|bool, optional): The attribute for the learnable bias + of this layer. If it is set to False, no bias will be added to the output. + If it is set to None or one kind of ParamAttr, a bias parameter will + be created according to ParamAttr. For detailed information, please refer + to paddle.ParamAttr. The default value is None and the bias will be + initialized to zero. + name (str, optional): Normally there is no need for user to set this parameter. + For detailed information, please refer to :ref:`api_guide_Name` . + + Attribute: + **weight** (Parameter): the learnable weight of this layer. + + **bias** (Parameter): the learnable bias of this layer. + + Shape: + - input: Multi-dimentional tensor with shape :math:`[batch\_size, *, in\_features]` . + - output: Multi-dimentional tensor with shape :math:`[batch\_size, *, out\_features]` . + + Examples: + .. 
code-block:: python + + # required: gpu + import paddle + from paddle.incubate.nn import FusedLinear + + x = paddle.randn([3, 4]) + linear = FusedLinear(4, 5) + y = linear(x) + print(y.shape) # [3, 5] + """ + + def __init__(self, + in_features, + out_features, + weight_attr=None, + bias_attr=None, + transpose_weight=False, + name=None): + super(FusedLinear, self).__init__() + if transpose_weight: + weight_shape = [out_features, in_features] + else: + weight_shape = [in_features, out_features] + dtype = self._helper.get_default_dtype() + self.weight = self.create_parameter( + shape=weight_shape, attr=weight_attr, dtype=dtype, is_bias=False) + self.bias = self.create_parameter( + shape=[out_features], attr=bias_attr, dtype=dtype, is_bias=True) + self.transpose_weight = transpose_weight + self.name = name + + def forward(self, input): + return F.fused_linear(input, self.weight, self.bias, + self.transpose_weight, self.name)
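
Usage note (dynamic graph): the snippet below is a minimal sketch of the new Python surface added in this diff, i.e. the fused_matmul_bias / fused_linear functionals and the FusedLinear layer under paddle.incubate.nn. It only uses names and signatures that appear in the diff (including the CUDA >= 11.6 / fused_gemm_epilogue availability check from the new test file); the tensor shapes are illustrative, not prescriptive.

import paddle
import paddle.fluid.core as core
from paddle.incubate.nn import FusedLinear
from paddle.incubate.nn.functional import fused_linear, fused_matmul_bias

# Mirrors is_fused_matmul_bias_supported() from test_fused_matmul_bias.py:
# the op is only registered on CUDA builds with cuBLASLt epilogue support.
assert paddle.is_compiled_with_cuda() and hasattr(core.ops, 'fused_gemm_epilogue')
paddle.set_device('gpu')

x = paddle.randn([3, 4])
y = paddle.randn([5, 4])        # stored transposed; consumed with transpose_y=True
bias = paddle.randn([5])
x.stop_gradient = False
y.stop_gradient = False

# Forward runs the fused_gemm_epilogue op; backward is wired up by the newly
# registered grad maker and dispatched to the trans_x/trans_y-aware grad kernel.
out = fused_matmul_bias(x, y, bias, transpose_x=False, transpose_y=True)
out.sum().backward()

# Layer-style API; equivalent to fused_linear(x, layer.weight, layer.bias, True).
layer = FusedLinear(4, 5, transpose_weight=True)
z = layer(x)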
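
Usage note (static graph): the sketch below shows how the new fuse_gemm_epilogue build-strategy flag is intended to drive fuse_gemm_epilogue_pass through paddle.fluid.ir.apply_build_strategy (the branch added to ir.py above). Two assumptions are labeled in the comments: that the fuse_gemm_epilogue field added to BuildStrategy in distributed_strategy.proto is exposed on paddle.static.BuildStrategy (the pybind change is not part of this excerpt), and that the final pass_attrs argument of apply_build_strategy is a dict carrying use_cuda (its full signature is truncated here).

import paddle
from paddle.fluid import ir

paddle.enable_static()
main_prog, startup_prog = paddle.static.Program(), paddle.static.Program()
with paddle.static.program_guard(main_prog, startup_prog):
    x = paddle.static.data(name='x', shape=[-1, 128], dtype='float32')
    # matmul_v2 + elementwise_add (+ relu/gelu) is the pattern FuseGemmEpiloguePass matches.
    hidden = paddle.nn.Linear(128, 256)(x)
    out = paddle.nn.functional.relu(hidden)

build_strategy = paddle.static.BuildStrategy()
# Assumption: the proto field added above is exposed on the Python BuildStrategy.
build_strategy.fuse_gemm_epilogue = True
# Assumption: pass_attrs is a dict; per the ir.py change, apply_build_strategy
# runs fuse_gemm_epilogue_pass and then resets the flag to False.
ir.apply_build_strategy(main_prog, startup_prog, build_strategy, {"use_cuda": True})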