Unverified commit 048b0013, authored by S sneaxiy, committed by GitHub

Make fuse_gemm_epilogue support transpose_x and transpose_y (#40558)

* support weight transpose

* add ut

* add template

* fix transpose error

* fix transpose_comment

* add api tests

* add skipif

* add doc
Parent 07993044
...@@ -120,6 +120,7 @@ message BuildStrategy {
optional bool fix_op_run_order = 13 [ default = false ];
optional bool allow_cuda_graph_capture = 14 [ default = false ];
optional int32 reduce_strategy = 15 [ default = 0 ];
optional bool fuse_gemm_epilogue = 16 [ default = false ];
}
message ExecutionStrategy {
......
...@@ -22,6 +22,12 @@ namespace paddle {
namespace framework {
namespace ir {
static void GetTransposeAttrsFromOp(const OpDesc &op, bool *trans_x,
bool *trans_y) {
*trans_x = BOOST_GET_CONST(bool, op.GetAttr("trans_x"));
*trans_y = BOOST_GET_CONST(bool, op.GetAttr("trans_y"));
}
void FuseGemmEpiloguePass::ApplyImpl(ir::Graph *graph) const {
EpiloguePassActivationCache cache;
...@@ -75,6 +81,9 @@ ir::Graph *FuseGemmEpiloguePass::FuseLinearFwd(ir::Graph *graph,
if (!IsGemmFromLinear_(matmul_x_shape, matmul_w_shape, matmul_op_desc))
return;
bool trans_x, trans_y;
GetTransposeAttrsFromOp(*matmul_op_desc, &trans_x, &trans_y);
OpDesc fused_gemm_epilogue_op_desc(matmul_op->Op()->Block());
std::string activation = "none";
fused_gemm_epilogue_op_desc.SetType("fused_gemm_epilogue");
...@@ -85,6 +94,8 @@ ir::Graph *FuseGemmEpiloguePass::FuseLinearFwd(ir::Graph *graph,
fused_gemm_epilogue_op_desc.SetAttr("activation", activation);
fused_gemm_epilogue_op_desc.SetAttr("op_role",
matmul_op_desc->GetAttr("op_role"));
fused_gemm_epilogue_op_desc.SetAttr("trans_x", trans_x);
fused_gemm_epilogue_op_desc.SetAttr("trans_y", trans_y);
auto gemm_epilogue_node = g->CreateOpNode(&fused_gemm_epilogue_op_desc);
IR_NODE_LINK_TO(subgraph.at(x), gemm_epilogue_node);
...@@ -154,6 +165,9 @@ ir::Graph *FuseGemmEpiloguePass::FuseLinearActFwd(
auto activation = act_op->Op()->Type();
bool trans_x, trans_y;
GetTransposeAttrsFromOp(*matmul_op_desc, &trans_x, &trans_y);
OpDesc fused_gemm_epilogue_op_desc(matmul_op->Op()->Block());
fused_gemm_epilogue_op_desc.SetType("fused_gemm_epilogue");
fused_gemm_epilogue_op_desc.SetInput("X", {subgraph.at(x)->Name()});
...@@ -163,6 +177,8 @@ ir::Graph *FuseGemmEpiloguePass::FuseLinearActFwd(
fused_gemm_epilogue_op_desc.SetAttr("activation", activation);
fused_gemm_epilogue_op_desc.SetAttr("op_role",
matmul_op_desc->GetAttr("op_role"));
fused_gemm_epilogue_op_desc.SetAttr("trans_x", trans_x);
fused_gemm_epilogue_op_desc.SetAttr("trans_y", trans_y);
auto gemm_epilogue_node = g->CreateOpNode(&fused_gemm_epilogue_op_desc);
...@@ -274,6 +290,9 @@ ir::Graph *FuseGemmEpiloguePass::FuseLinearBwd(ir::Graph *graph,
matmul_grad_op_desc))
return;
bool trans_x, trans_y;
GetTransposeAttrsFromOp(*matmul_grad_op_desc, &trans_x, &trans_y);
OpDesc fused_gemm_epilogue_grad_op_desc(ele_add_grad_op->Op()->Block());
std::string activation_grad = "none";
fused_gemm_epilogue_grad_op_desc.SetType("fused_gemm_epilogue_grad");
...@@ -292,6 +311,8 @@ ir::Graph *FuseGemmEpiloguePass::FuseLinearBwd(ir::Graph *graph,
activation_grad);
fused_gemm_epilogue_grad_op_desc.SetAttr(
"op_role", matmul_grad_op_desc->GetAttr("op_role"));
fused_gemm_epilogue_grad_op_desc.SetAttr("trans_x", trans_x);
fused_gemm_epilogue_grad_op_desc.SetAttr("trans_y", trans_y);
auto gemm_epilogue_grad_node =
g->CreateOpNode(&fused_gemm_epilogue_grad_op_desc);
...@@ -394,6 +415,8 @@ ir::Graph *FuseGemmEpiloguePass::FuseLinearActBwd(
auto activation_grad = act_grad_op->Op()->Type();
bool trans_x, trans_y;
GetTransposeAttrsFromOp(*matmul_grad_op_desc, &trans_x, &trans_y);
OpDesc fused_gemm_epilogue_grad_op_desc(ele_add_grad_op->Op()->Block());
fused_gemm_epilogue_grad_op_desc.SetType("fused_gemm_epilogue_grad");
fused_gemm_epilogue_grad_op_desc.SetInput("DOut",
...@@ -410,6 +433,8 @@ ir::Graph *FuseGemmEpiloguePass::FuseLinearActBwd(
activation_grad);
fused_gemm_epilogue_grad_op_desc.SetAttr(
"op_role", matmul_grad_op_desc->GetAttr("op_role"));
fused_gemm_epilogue_grad_op_desc.SetAttr("trans_x", trans_x);
fused_gemm_epilogue_grad_op_desc.SetAttr("trans_y", trans_y);
auto gemm_epilogue_grad_node =
g->CreateOpNode(&fused_gemm_epilogue_grad_op_desc);
...@@ -456,10 +481,6 @@ bool FuseGemmEpiloguePass::IsGemmFromLinear_(
if (tmp_vec.size() > 0) return false;
}
}
if (BOOST_GET_CONST(bool, matmul_v2_op->GetAttr("trans_x")) ||
BOOST_GET_CONST(bool, matmul_v2_op->GetAttr("trans_y")))
return false;
return true;
}
......
...@@ -13,6 +13,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/fused/fused_gemm_epilogue_op.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/op_version_registry.h"
...@@ -208,6 +209,9 @@ class FusedGemmEpilogueGradOp : public framework::OperatorWithKernel {
auto x_dims = ctx->GetInputDim("X");
auto y_dims = ctx->GetInputDim("Y");
auto trans_x = ctx->Attrs().Get<bool>("trans_x");
auto trans_y = ctx->Attrs().Get<bool>("trans_y");
PADDLE_ENFORCE_GE(
dout_dims.size(), 2,
platform::errors::InvalidArgument(
...@@ -242,14 +246,14 @@ class FusedGemmEpilogueGradOp : public framework::OperatorWithKernel {
auto x_mat_dims = phi::flatten_to_2d(x_dims, x_dims.size() - 1);
PADDLE_ENFORCE_EQ(
dout_mat_dims[1], trans_y ? y_dims[0] : y_dims[1],
platform::errors::InvalidArgument(
"The last dimension of DOut should be equal with Y's last"
"dimension. But received DOut[-1] = [%d], Y[1] = [%d].",
dout_mat_dims[1], y_dims[1]));
PADDLE_ENFORCE_EQ(
dout_mat_dims[0], trans_x ? x_mat_dims[1] : x_mat_dims[0],
platform::errors::InvalidArgument(
"The first dimension of DOut should be equal with X's first"
"dimension. But received DOut[0] = [%d], Y[0] = [%d].",
...@@ -288,7 +292,7 @@ class FusedGemmEpilogueGradOp : public framework::OperatorWithKernel {
if (ctx->HasOutput("DBias")) {
std::vector<int64_t> dbias_dims;
dbias_dims.push_back(trans_y ? y_dims[0] : y_dims[1]);
ctx->SetOutputDim("DBias", phi::make_ddim(dbias_dims));
}
}
...@@ -323,6 +327,20 @@ class FusedGemmEpilogueGradOpMaker : public framework::OpProtoAndCheckerMaker {
AddOutput("DBias",
"The output grad tensor to bias of Out = (Act(X) * Y) + bias.")
.AsDispensable();
AddAttr<bool>(
"trans_x",
R"DOC((bool, default false), Whether to transpose input tensor X
or not. The input tensor X could be more than two-dimensional. When
set trans_x=true, it would fully reverse X. For instance: X with shape
[d0, d1, d2, d3] -> [d3, d2, d1, d0].)DOC")
.SetDefault(false);
AddAttr<bool>(
"trans_y",
R"DOC((bool, default false), Whether to transpose input tensor Y
or not. The input tensor Y should be two-dimensional. When
set trans_y=true, it would transpose Y. For instance: Y with shape
[d0, d1] -> [d1, d0].)DOC")
.SetDefault(false);
AddAttr<std::string>(
"activation_grad",
...@@ -343,11 +361,38 @@ X with shape [d0, d1, d2, d3] -> X_2D with shape [d0*d1*d2, d3]
}
};
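For the plain 2-D case, the transpose semantics these attributes refer to can be sketched in NumPy (an illustrative reference only, mirroring the matmul helper used by the new Python tests; the multi-dimensional X case additionally flattens X to 2-D as described in the attribute comments):

import numpy as np

def fused_gemm_epilogue_reference(x, y, bias=None, trans_x=False, trans_y=False):
    # 2-D reference: Out = op(X) @ op(Y) + bias, where op() transposes when requested.
    a = x.T if trans_x else x
    b = y.T if trans_y else y
    out = a @ b
    return out if bias is None else out + bias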
template <typename T>
class FusedGemmEpilogueOpGradMaker : public framework::SingleGradOpMaker<T> {
public:
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
protected:
void Apply(GradOpPtr<T> op) const override {
const auto& act_type = this->template Attr<std::string>("activation");
PADDLE_ENFORCE_EQ(act_type, "none", phi::errors::InvalidArgument(
"The activation should be none."));
op->SetType(this->ForwardOpType() + "_grad");
op->SetInput("X", this->Input("X"));
op->SetInput("Y", this->Input("Y"));
op->SetInput("DOut", this->OutputGrad("Out"));
op->SetOutput("DX", this->InputGrad("X"));
op->SetOutput("DY", this->InputGrad("Y"));
op->SetOutput("DBias", this->InputGrad("Bias"));
op->SetAttrMap(this->Attrs());
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(
fused_gemm_epilogue, ops::FusedGemmEpilogueOp,
ops::FusedGemmEpilogueOpMaker,
ops::FusedGemmEpilogueOpGradMaker<paddle::framework::OpDesc>,
ops::FusedGemmEpilogueOpGradMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(fused_gemm_epilogue_grad, ops::FusedGemmEpilogueGradOp,
ops::FusedGemmEpilogueGradOpMaker);
...@@ -15,6 +15,7 @@ limitations under the License. */
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/framework/scope_guard.h"
#include "paddle/fluid/operators/fused/fused_gemm_epilogue_op.h"
#include "paddle/fluid/platform/dynload/cublasLt.h"
#include "paddle/fluid/platform/float16.h"
...@@ -41,6 +42,8 @@ class FusedGemmEpilogueKernel : public framework::OpKernel<T> {
bool trans_y = ctx.Attr<bool>("trans_y");
std::string activation = ctx.Attr<std::string>("activation");
VLOG(10) << "trans_x = " << trans_x << " , trans_y = " << trans_y
<< " , activation = " << activation;
bool enable_auxiliary = reserve_space == nullptr ? false : true;
out->mutable_data<T>(ctx.GetPlace());
...@@ -48,6 +51,7 @@ class FusedGemmEpilogueKernel : public framework::OpKernel<T> {
auto x_mat_dims =
phi::flatten_to_2d(x->dims(), trans_x ? 1 : x->dims().size() - 1);
// (M * K) * (K * N)
int64_t M = trans_x ? x_mat_dims[1] : x_mat_dims[0];
int64_t K = trans_y ? y->dims()[1] : y->dims()[0];
int64_t N = trans_y ? y->dims()[0] : y->dims()[1];
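As a worked illustration of the shape handling above (an editorial example, not part of the source): with trans_x = false and X of shape [2, 3, 4], x_mat_dims becomes [6, 4] and M = 6; with Y of shape [4, 5] and trans_y = false, K = 4 and N = 5. If trans_y = true instead, Y is stored as [5, 4] and the kernel still derives K = 4 and N = 5.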
...@@ -106,10 +110,11 @@ class FusedGemmEpilogueKernel : public framework::OpKernel<T> {
platform::dynload::cublasLtMatmulDescSetAttribute(
operation_desc, CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_POINTER,
&aux_data, sizeof(aux_data)));
int64_t aux_ld = N;
PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::cublasLtMatmulDescSetAttribute(
operation_desc, CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_LD, &aux_ld,
sizeof(aux_ld)));
}
cublasLtMatrixLayout_t x_desc = NULL, y_desc = NULL, out_desc = NULL;
...@@ -129,8 +134,7 @@ class FusedGemmEpilogueKernel : public framework::OpKernel<T> {
&out_desc, mat_type, N, M, N));
cublasLtHandle_t lt_handle = dev_ctx.cublaslt_handle();
size_t workspace_size = static_cast<size_t>(4) * 1024 * 1024 * 1024;
cudaStream_t stream = dev_ctx.stream();
memory::allocation::AllocationPtr workspace =
memory::Alloc(dev_ctx, workspace_size);
...@@ -149,13 +153,13 @@ class FusedGemmEpilogueKernel : public framework::OpKernel<T> {
const auto* y_data = y->data<T>();
const auto* x_data = x->data<T>();
auto algo = GemmEpilogueAlgoCache::Instance().GetGemmAlgo(
lt_handle, operation_desc, y_desc, x_desc, out_desc, alpha, beta,
y_data, x_data, out_data, stream, workspace->ptr(), workspace_size);
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasLtMatmul(
lt_handle, operation_desc, alpha, y_data, y_desc, x_data, x_desc, beta,
out_data, out_desc, out_data, out_desc, algo, workspace->ptr(),
workspace_size, stream));
PADDLE_ENFORCE_GPU_SUCCESS(
...@@ -191,12 +195,94 @@ class FusedGemmEpilogueKernel : public framework::OpKernel<T> {
}
};
enum FusedGEMMGradInType { kDX = 0, kDY = 1, kDZ = 2 };
template <bool TransX, bool TransY>
struct FusedGEMMGradTrait;
template <>
struct FusedGEMMGradTrait<false, false> {
static constexpr auto kXGradA = FusedGEMMGradInType::kDZ;
static constexpr auto kXGradB = FusedGEMMGradInType::kDY;
static constexpr auto kXGradATrans = false;
static constexpr auto kXGradBTrans = true;
static constexpr auto kYGradA = FusedGEMMGradInType::kDX;
static constexpr auto kYGradB = FusedGEMMGradInType::kDZ;
static constexpr auto kYGradATrans = true;
static constexpr auto kYGradBTrans = false;
};
template <>
struct FusedGEMMGradTrait<true, false> {
static constexpr auto kXGradA = FusedGEMMGradInType::kDY;
static constexpr auto kXGradB = FusedGEMMGradInType::kDZ;
static constexpr auto kXGradATrans = false;
static constexpr auto kXGradBTrans = true;
static constexpr auto kYGradA = FusedGEMMGradInType::kDX;
static constexpr auto kYGradB = FusedGEMMGradInType::kDZ;
static constexpr auto kYGradATrans = false;
static constexpr auto kYGradBTrans = false;
};
template <>
struct FusedGEMMGradTrait<false, true> {
static constexpr auto kXGradA = FusedGEMMGradInType::kDZ;
static constexpr auto kXGradB = FusedGEMMGradInType::kDY;
static constexpr auto kXGradATrans = false;
static constexpr auto kXGradBTrans = false;
static constexpr auto kYGradA = FusedGEMMGradInType::kDZ;
static constexpr auto kYGradB = FusedGEMMGradInType::kDX;
static constexpr auto kYGradATrans = true;
static constexpr auto kYGradBTrans = false;
};
template <>
struct FusedGEMMGradTrait<true, true> {
static constexpr auto kXGradA = FusedGEMMGradInType::kDY;
static constexpr auto kXGradB = FusedGEMMGradInType::kDZ;
static constexpr auto kXGradATrans = true;
static constexpr auto kXGradBTrans = true;
static constexpr auto kYGradA = FusedGEMMGradInType::kDZ;
static constexpr auto kYGradB = FusedGEMMGradInType::kDX;
static constexpr auto kYGradATrans = true;
static constexpr auto kYGradBTrans = true;
};
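The four specializations above encode the standard matrix-calculus results for Z = op(X) op(Y); spelled out for reference (dZ denotes the gradient of the output), with kXGradA/kXGradB naming the first and second GEMM operand used to form dX and the *Trans flags marking which operand is transposed (likewise for dY):

Z = X Y      : dX = dZ Y^T,    dY = X^T dZ
Z = X^T Y    : dX = Y dZ^T,    dY = X dZ
Z = X Y^T    : dX = dZ Y,      dY = dZ^T X
Z = X^T Y^T  : dX = Y^T dZ^T,  dY = dZ^T X^T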
static constexpr auto BoolToCuBlasEnum(bool transpose) {
return transpose ? CUBLAS_OP_T : CUBLAS_OP_N;
}
template <typename DeviceContext, typename T>
class FusedGemmEpilogueGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
bool transpose_x = ctx.Attr<bool>("trans_x");
bool transpose_y = ctx.Attr<bool>("trans_y");
if (transpose_x) {
if (transpose_y) {
ComputeImpl<true, true>(ctx);
} else {
ComputeImpl<true, false>(ctx);
}
} else {
if (transpose_y) {
ComputeImpl<false, true>(ctx);
} else {
ComputeImpl<false, false>(ctx);
}
}
}
private:
template <bool TransX, bool TransY>
static void ComputeImpl(const framework::ExecutionContext& ctx) {
using Trait = FusedGEMMGradTrait<TransX, TransY>;
auto& dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
const Tensor* dout = ctx.Input<Tensor>("DOut");
const Tensor* x = ctx.Input<Tensor>("X");
const Tensor* y = ctx.Input<Tensor>("Y");
...@@ -208,13 +294,18 @@ class FusedGemmEpilogueGradKernel : public framework::OpKernel<T> {
std::string activation_grad = ctx.Attr<std::string>("activation_grad");
VLOG(10) << "trans_x = " << TransX << " , trans_y = " << TransY
<< " , activation_grad = " << activation_grad;
auto x_mat_dims =
phi::flatten_to_2d(x->dims(), TransX ? 1 : x->dims().size() - 1);
// (M * K) * (K * N)
int64_t M = TransX ? x_mat_dims[1] : x_mat_dims[0];
int64_t K = TransY ? y->dims()[1] : y->dims()[0];
int64_t N = TransY ? y->dims()[0] : y->dims()[1];
VLOG(10) << "M = " << M << " , K = " << K << " , N = " << N;
cudaDataType_t mat_type = CUDA_R_32F;
cudaDataType_t scale_type = CUDA_R_32F;
...@@ -229,7 +320,8 @@ class FusedGemmEpilogueGradKernel : public framework::OpKernel<T> {
}
cublasLtHandle_t lt_handle = dev_ctx.cublaslt_handle();
size_t workspace_size = static_cast<size_t>(4) * 1024 * 1024 * 1024;
const cublasLtMatmulAlgo_t* algo = nullptr;
cudaStream_t stream = dev_ctx.stream();
double alpha64 = 1.0, beta64 = 0.0;
...@@ -243,24 +335,81 @@ class FusedGemmEpilogueGradKernel : public framework::OpKernel<T> {
beta = &beta32;
}
cublasLtMatrixLayout_t dout_desc = nullptr, dout_trans_desc = nullptr;
cublasLtMatrixLayout_t x_desc = nullptr, x_trans_desc = nullptr;
cublasLtMatrixLayout_t y_desc = nullptr, y_trans_desc = nullptr;
cublasLtMatrixLayout_t dx_desc = nullptr, dy_desc = nullptr;
cublasLtMatmulDesc_t dx_operation_desc = nullptr,
dy_operation_desc = nullptr;
DEFINE_PADDLE_SCOPE_GUARD([&] {
auto descs = {dout_desc, dout_trans_desc, x_desc, x_trans_desc,
y_desc, y_trans_desc, dx_desc, dy_desc};
for (auto desc : descs) {
if (desc) {
PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::cublasLtMatrixLayoutDestroy(desc));
}
}
if (dx_operation_desc) {
PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::cublasLtMatmulDescDestroy(dx_operation_desc));
}
if (dy_operation_desc) {
PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::cublasLtMatmulDescDestroy(dy_operation_desc));
}
});
auto x_row = TransX ? K : M;
auto x_col = TransX ? M : K;
auto y_row = TransY ? N : K;
auto y_col = TransY ? K : N;
auto z_row = TransX ? N : M;
auto z_col = TransX ? M : N;
// dx = func(dout, y)
if (dx) {
constexpr auto kXGradAIsDZ = (Trait::kXGradA == FusedGEMMGradInType::kDZ);
cublasLtMatrixLayout_t *dx_dout_desc, *dx_y_desc;
if (TransX) {
dx_dout_desc = &dout_trans_desc;
PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::cublasLtMatrixLayoutCreate(
dx_dout_desc, mat_type, z_row, z_col, z_row));
} else {
dx_dout_desc = &dout_desc;
PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::cublasLtMatrixLayoutCreate(
dx_dout_desc, mat_type, z_col, z_row, z_col));
}
dx_y_desc = &y_trans_desc;
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasLtMatrixLayoutCreate(
dx_y_desc, mat_type, y_col, y_row, y_col));
auto& a_desc = kXGradAIsDZ ? (*dx_dout_desc) : (*dx_y_desc);
auto& b_desc = kXGradAIsDZ ? (*dx_y_desc) : (*dx_dout_desc);
auto a_trans = BoolToCuBlasEnum(Trait::kXGradATrans);
auto b_trans = BoolToCuBlasEnum(Trait::kXGradBTrans);
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasLtMatrixLayoutCreate(
&dx_desc, mat_type, x_col, x_row, x_col));
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasLtMatmulDescCreate(
&dx_operation_desc, compute_type, scale_type));
PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::cublasLtMatmulDescSetAttribute(
dx_operation_desc, CUBLASLT_MATMUL_DESC_TRANSB, &a_trans,
sizeof(a_trans)));
PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::cublasLtMatmulDescSetAttribute(
dx_operation_desc, CUBLASLT_MATMUL_DESC_TRANSA, &b_trans,
sizeof(b_trans)));
cublasLtEpilogue_t epiloque_func_for_dx =
get_epilogue_type_(activation_grad);
PADDLE_ENFORCE_GPU_SUCCESS(
...@@ -274,105 +423,116 @@ class FusedGemmEpilogueGradKernel : public framework::OpKernel<T> {
platform::dynload::cublasLtMatmulDescSetAttribute(
dx_operation_desc, CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_POINTER,
&aux_data, sizeof(aux_data)));
int64_t aux_ld = TransX ? M : K;
PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::cublasLtMatmulDescSetAttribute(
dx_operation_desc, CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_LD,
&aux_ld, sizeof(aux_ld)));
}
auto dx_workspace = memory::Alloc(dev_ctx, workspace_size);
auto* dx_data = dx->mutable_data<T>(ctx.GetPlace());
const auto* y_data = y->data<T>();
const auto* dout_data = dout->data<T>();
const auto* a_data = kXGradAIsDZ ? dout_data : y_data;
const auto* b_data = kXGradAIsDZ ? y_data : dout_data;
auto algo = GemmEpilogueAlgoCache::Instance().GetGemmAlgo(
lt_handle, dx_operation_desc, b_desc, a_desc, dx_desc, alpha, beta,
b_data, a_data, dx_data, stream, dx_workspace->ptr(), workspace_size);
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasLtMatmul(
lt_handle, dx_operation_desc, alpha, b_data, b_desc, a_data, a_desc,
beta, dx_data, dx_desc, dx_data, dx_desc, algo, dx_workspace->ptr(),
workspace_size, stream));
}
// dy = func(dout, x)
if (dy) {
constexpr auto kYGradAIsDZ = (Trait::kYGradA == FusedGEMMGradInType::kDZ);
cublasLtMatrixLayout_t *dy_dout_desc = nullptr, *dy_x_desc = nullptr;
if (TransX) {
dy_dout_desc = &dout_trans_desc;
if (dout_trans_desc == nullptr) {
PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::cublasLtMatrixLayoutCreate(
dy_dout_desc, mat_type, z_row, z_col, z_row));
}
} else {
dy_dout_desc = &dout_desc;
if (dout_desc == nullptr) {
PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::cublasLtMatrixLayoutCreate(
dy_dout_desc, mat_type, z_col, z_row, z_col));
}
}
dy_x_desc = &x_trans_desc;
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasLtMatrixLayoutCreate(
dy_x_desc, mat_type, x_col, x_row, x_col));
auto& a_desc = kYGradAIsDZ ? (*dy_dout_desc) : (*dy_x_desc);
auto& b_desc = kYGradAIsDZ ? (*dy_x_desc) : (*dy_dout_desc);
auto a_trans = BoolToCuBlasEnum(Trait::kYGradATrans);
auto b_trans = BoolToCuBlasEnum(Trait::kYGradBTrans);
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasLtMatrixLayoutCreate(
&dy_desc, mat_type, y_col, y_row, y_col));
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasLtMatmulDescCreate(
&dy_operation_desc, compute_type, scale_type));
PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::cublasLtMatmulDescSetAttribute(
dy_operation_desc, CUBLASLT_MATMUL_DESC_TRANSB, &a_trans,
sizeof(a_trans)));
PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::cublasLtMatmulDescSetAttribute(
dy_operation_desc, CUBLASLT_MATMUL_DESC_TRANSA, &b_trans,
sizeof(b_trans)));
cublasLtEpilogue_t epiloque_func_for_dy;
if (dbias == nullptr) {
epiloque_func_for_dy = CUBLASLT_EPILOGUE_DEFAULT;
} else {
if (TransY) {
epiloque_func_for_dy = CUBLASLT_EPILOGUE_BGRADB;
} else {
epiloque_func_for_dy = CUBLASLT_EPILOGUE_BGRADA;
}
}
PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::cublasLtMatmulDescSetAttribute(
dy_operation_desc, CUBLASLT_MATMUL_DESC_EPILOGUE,
&epiloque_func_for_dy, sizeof(epiloque_func_for_dy)));
if (dbias) {
auto* dbias_data = dbias->mutable_data<T>(ctx.GetPlace());
PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::cublasLtMatmulDescSetAttribute(
dy_operation_desc, CUBLASLT_MATMUL_DESC_BIAS_POINTER,
&dbias_data, sizeof(dbias_data)));
}
auto dy_workspace = memory::Alloc(dev_ctx, workspace_size);
auto* dy_data = dy->mutable_data<T>(ctx.GetPlace());
const auto* dout_data = dout->data<T>();
const auto* x_data = x->data<T>();
const auto* a_data = kYGradAIsDZ ? dout_data : x_data;
const auto* b_data = kYGradAIsDZ ? x_data : dout_data;
auto algo = GemmEpilogueAlgoCache::Instance().GetGemmAlgo(
lt_handle, dy_operation_desc, b_desc, a_desc, dy_desc, alpha, beta,
b_data, a_data, dy_data, stream, dy_workspace->ptr(), workspace_size);
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasLtMatmul(
lt_handle, dy_operation_desc, alpha, b_data, b_desc, a_data, a_desc,
beta, dy_data, dy_desc, dy_data, dy_desc, algo, dy_workspace->ptr(),
workspace_size, stream));
}
}
private:
......
...@@ -21,7 +21,9 @@ limitations under the License. */
#include <unordered_map>
#include "gflags/gflags.h"
#include "paddle/fluid/platform/dynload/cublasLt.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/float16.h"
#include "paddle/utils/optional.h"
DECLARE_int64(cublaslt_exhaustive_search_times);
...@@ -39,12 +41,14 @@ class GemmEpilogueAlgoCache {
GemmEpilogueAlgoCache(GemmEpilogueAlgoCache const &) = delete;
void operator=(GemmEpilogueAlgoCache const &) = delete;
cublasLtMatmulAlgo_t *GetGemmAlgo(
cublasLtHandle_t lt_handle, cublasLtMatmulDesc_t op_desc,
cublasLtMatrixLayout_t a_desc, cublasLtMatrixLayout_t b_desc,
cublasLtMatrixLayout_t c_desc, const void *alpha, const void *beta,
const void *a, const void *b, void *c, cudaStream_t stream,
void *workspace, size_t workspace_size) {
if (search_times_ <= 0) return nullptr;
int64_t seed = 0;
std::hash<int64_t> hash_fn;
...@@ -54,132 +58,108 @@ class GemmEpilogueAlgoCache {
HashMatrixLayoutDesc_(c_desc, &seed, hash_fn);
cublasLtMatmulAlgo_t ret;
{
std::lock_guard<std::mutex> lock(cache_mutex_);
auto it = map_.find(seed);
if (it != map_.end()) {
return &(it->second);
}
}
cublasLtMatmulPreference_t preference;
PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::cublasLtMatmulPreferenceCreate(&preference));
PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::cublasLtMatmulPreferenceSetAttribute(
preference, CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES,
&workspace_size, sizeof(workspace_size)));
int returned_results = 0;
std::vector<cublasLtMatmulHeuristicResult_t> heuristic_results(
requested_algo_count_);
PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::cublasLtMatmulAlgoGetHeuristic(
lt_handle, op_desc, a_desc, b_desc, c_desc, c_desc, preference,
requested_algo_count_, heuristic_results.data(),
&returned_results));
PADDLE_ENFORCE_GT(
returned_results, 0,
platform::errors::Unavailable("No GEMM epilogue algorithm support!"));
PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::cublasLtMatmulPreferenceDestroy(preference));
int best_algo_idx = -1;
float best_algo_time = 0;
// Run 100 times for warmup
int warmup_algo_idx = 0;
for (int t = 0; t < 100; t++) {
cublasStatus_t status = platform::dynload::cublasLtMatmul(
lt_handle, op_desc, alpha, a, a_desc, b, b_desc, beta, c, c_desc, c,
c_desc, &heuristic_results[warmup_algo_idx].algo, workspace,
workspace_size, stream);
if (status != CUBLAS_STATUS_SUCCESS) {
t = -1;
warmup_algo_idx += 1;
if (warmup_algo_idx == requested_algo_count_) {
PADDLE_THROW(platform::errors::Unavailable(
"No GEMM epilogue algorithm support!"));
}
}
}
cudaEvent_t start_event, stop_event;
PADDLE_ENFORCE_GPU_SUCCESS(cudaEventCreate(&start_event));
PADDLE_ENFORCE_GPU_SUCCESS(cudaEventCreate(&stop_event));
for (int algo_idx = 0; algo_idx < returned_results; ++algo_idx) {
float curr_time = 0;
for (int check_idx = 0; check_idx < search_times_; check_idx++) {
float time = 0;
PADDLE_ENFORCE_GPU_SUCCESS(cudaEventRecord(start_event, stream));
cublasStatus_t status = platform::dynload::cublasLtMatmul(
lt_handle, op_desc, alpha, a, a_desc, b, b_desc, beta, c, c_desc, c,
c_desc, &heuristic_results[algo_idx].algo, workspace,
workspace_size, stream);
PADDLE_ENFORCE_GPU_SUCCESS(cudaEventRecord(stop_event, stream));
PADDLE_ENFORCE_GPU_SUCCESS(cudaEventSynchronize(stop_event));
PADDLE_ENFORCE_GPU_SUCCESS(
cudaEventElapsedTime(&time, start_event, stop_event));
curr_time += time;
if (status != CUBLAS_STATUS_SUCCESS) {
curr_time = 3.40282e+038; // Max Value of float
break;
}
}
curr_time = curr_time / search_times_;
if (curr_time < best_algo_time || algo_idx == 0) {
best_algo_idx = algo_idx;
best_algo_time = curr_time;
}
} }
VLOG(4) << "Search time:" << search_times_ << ", Is hash-key (" << seed PADDLE_ENFORCE_GPU_SUCCESS(cudaEventDestroy(start_event));
<< ") found in GemmEpilogueAlgoCache? " << have_found; PADDLE_ENFORCE_GPU_SUCCESS(cudaEventDestroy(stop_event));
if (best_algo_idx == -1) {
PADDLE_THROW(
platform::errors::Unavailable("No GEMM epilogue algorithm support!"));
}
ret = heuristic_results[best_algo_idx].algo;
VLOG(4) << "Search time:" << search_times_ << ", hash-key (" << seed
<< ") not found in GemmEpilogueAlgoCache";
std::lock_guard<std::mutex> lock(cache_mutex_);
auto &algo_in_map = map_[seed];
algo_in_map = ret;
return &algo_in_map;
}
private:
......
...@@ -101,6 +101,9 @@ def apply_build_strategy(main_program, startup_program, build_strategy,
if build_strategy.enable_auto_fusion and use_cuda:
apply_pass("fusion_group_pass")
build_strategy.enable_auto_fusion = False
if build_strategy.fuse_gemm_epilogue:
apply_pass("fuse_gemm_epilogue_pass")
build_strategy.fuse_gemm_epilogue = False
if build_strategy.fuse_elewise_add_act_ops:
apply_pass("fuse_elewise_add_act_pass")
build_strategy.fuse_elewise_add_act_ops = False
......
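A minimal sketch of how the new build-strategy flag is meant to be used from Python (editorial example; it assumes the new proto option is exposed as a BuildStrategy attribute, as the patched apply_build_strategy above expects, and the driver wiring is not shown in this diff):

import paddle

paddle.enable_static()
build_strategy = paddle.static.BuildStrategy()
build_strategy.fuse_gemm_epilogue = True  # new option added by this commit
# ... build main_program / startup_program, then hand them to the helper above, e.g.:
# apply_build_strategy(main_program, startup_program, build_strategy, {"use_cuda": True})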
...@@ -263,6 +263,18 @@ def skip_check_grad_ci(reason=None):
return wrapper
def skip_check_inplace_ci(reason=None):
if not isinstance(reason, str):
raise AssertionError(
"The reason for skipping check_inplace is required.")
def wrapper(cls):
cls.no_need_check_inplace = True
return cls
return wrapper
def copy_bits_from_float_to_uint16(f):
return struct.unpack('<I', struct.pack('<f', f))[0] >> 16
...@@ -1288,6 +1300,9 @@ class OpTest(unittest.TestCase):
Returns:
None
"""
if getattr(self, "no_need_check_inplace", False):
return
has_infer_inplace = fluid.core.has_infer_inplace(self.op_type)
has_grad_op_maker = fluid.core.has_grad_op_maker(self.op_type)
......
...@@ -235,5 +235,6 @@ class TestFuseGemmEpilogueGradOpDXYFP64(TestFuseGemmEpilogueGradOpDXYFP16):
if __name__ == "__main__":
paddle.enable_static()
np.random.seed(0)
unittest.main()
...@@ -19,7 +19,7 @@ import unittest
import numpy as np
import paddle
import paddle.fluid.core as core
from op_test import OpTest, skip_check_grad_ci, skip_check_inplace_ci
def gelu(x):
...@@ -43,10 +43,15 @@ def get_output(X, Y, bias, act):
return out
@skip_check_inplace_ci(reason="no inplace op")
class TestFuseGemmBase(OpTest):
pass
@skip_check_grad_ci(reason="no grap op") @skip_check_grad_ci(reason="no grap op")
@unittest.skipIf(not core.is_compiled_with_cuda(), @unittest.skipIf(not core.is_compiled_with_cuda(),
"core is not compiled with CUDA") "core is not compiled with CUDA")
class TestFuseGemmEpilogueOpReluMMFP16(OpTest): class TestFuseGemmEpilogueOpReluMMFP16(TestFuseGemmBase):
def setUp(self): def setUp(self):
self.op_type = "fused_gemm_epilogue" self.op_type = "fused_gemm_epilogue"
self.place = core.CUDAPlace(0) self.place = core.CUDAPlace(0)
...@@ -95,7 +100,7 @@ class TestFuseGemmEpilogueOpReluMMFP64(TestFuseGemmEpilogueOpReluMMFP16): ...@@ -95,7 +100,7 @@ class TestFuseGemmEpilogueOpReluMMFP64(TestFuseGemmEpilogueOpReluMMFP16):
@skip_check_grad_ci(reason="no grap op") @skip_check_grad_ci(reason="no grap op")
@unittest.skipIf(not core.is_compiled_with_cuda(), @unittest.skipIf(not core.is_compiled_with_cuda(),
"core is not compiled with CUDA") "core is not compiled with CUDA")
class TestFuseGemmEpilogueOpReluMTMFP16(OpTest): class TestFuseGemmEpilogueOpReluMTMFP16(TestFuseGemmBase):
def setUp(self): def setUp(self):
self.op_type = "fused_gemm_epilogue" self.op_type = "fused_gemm_epilogue"
self.place = core.CUDAPlace(0) self.place = core.CUDAPlace(0)
...@@ -144,7 +149,7 @@ class TestFuseGemmEpilogueOpReluMTMFP64(TestFuseGemmEpilogueOpReluMTMFP16): ...@@ -144,7 +149,7 @@ class TestFuseGemmEpilogueOpReluMTMFP64(TestFuseGemmEpilogueOpReluMTMFP16):
@skip_check_grad_ci(reason="no grap op") @skip_check_grad_ci(reason="no grap op")
@unittest.skipIf(not core.is_compiled_with_cuda(), @unittest.skipIf(not core.is_compiled_with_cuda(),
"core is not compiled with CUDA") "core is not compiled with CUDA")
class TestFuseGemmEpilogueOpReluMMTFP16(OpTest): class TestFuseGemmEpilogueOpReluMMTFP16(TestFuseGemmBase):
def setUp(self): def setUp(self):
self.op_type = "fused_gemm_epilogue" self.op_type = "fused_gemm_epilogue"
self.place = core.CUDAPlace(0) self.place = core.CUDAPlace(0)
...@@ -193,7 +198,7 @@ class TestFuseGemmEpilogueOpReluMMTFP64(TestFuseGemmEpilogueOpReluMMTFP16): ...@@ -193,7 +198,7 @@ class TestFuseGemmEpilogueOpReluMMTFP64(TestFuseGemmEpilogueOpReluMMTFP16):
@skip_check_grad_ci(reason="no grap op") @skip_check_grad_ci(reason="no grap op")
@unittest.skipIf(not core.is_compiled_with_cuda(), @unittest.skipIf(not core.is_compiled_with_cuda(),
"core is not compiled with CUDA") "core is not compiled with CUDA")
class TestFuseGemmEpilogueOpReluMTMTFP16(OpTest): class TestFuseGemmEpilogueOpReluMTMTFP16(TestFuseGemmBase):
def setUp(self): def setUp(self):
self.op_type = "fused_gemm_epilogue" self.op_type = "fused_gemm_epilogue"
self.place = core.CUDAPlace(0) self.place = core.CUDAPlace(0)
...@@ -242,7 +247,7 @@ class TestFuseGemmEpilogueOpReluMTMTFP64(TestFuseGemmEpilogueOpReluMTMTFP16): ...@@ -242,7 +247,7 @@ class TestFuseGemmEpilogueOpReluMTMTFP64(TestFuseGemmEpilogueOpReluMTMTFP16):
@skip_check_grad_ci(reason="no grap op") @skip_check_grad_ci(reason="no grap op")
@unittest.skipIf(not core.is_compiled_with_cuda(), @unittest.skipIf(not core.is_compiled_with_cuda(),
"core is not compiled with CUDA") "core is not compiled with CUDA")
class TestFuseGemmEpilogueOpReluMMFP16MultiDimX(OpTest): class TestFuseGemmEpilogueOpReluMMFP16MultiDimX(TestFuseGemmBase):
def setUp(self): def setUp(self):
self.op_type = "fused_gemm_epilogue" self.op_type = "fused_gemm_epilogue"
self.place = core.CUDAPlace(0) self.place = core.CUDAPlace(0)
...@@ -294,7 +299,7 @@ class TestFuseGemmEpilogueOpReluMMFP64MultiDimX( ...@@ -294,7 +299,7 @@ class TestFuseGemmEpilogueOpReluMMFP64MultiDimX(
@skip_check_grad_ci(reason="no grap op") @skip_check_grad_ci(reason="no grap op")
@unittest.skipIf(not core.is_compiled_with_cuda(), @unittest.skipIf(not core.is_compiled_with_cuda(),
"core is not compiled with CUDA") "core is not compiled with CUDA")
class TestFuseGemmEpilogueOpReluMTMFP16MultiDimX(OpTest): class TestFuseGemmEpilogueOpReluMTMFP16MultiDimX(TestFuseGemmBase):
def setUp(self): def setUp(self):
self.op_type = "fused_gemm_epilogue" self.op_type = "fused_gemm_epilogue"
self.place = core.CUDAPlace(0) self.place = core.CUDAPlace(0)
...@@ -346,7 +351,7 @@ class TestFuseGemmEpilogueOpReluMTMFP64MultiDimX( ...@@ -346,7 +351,7 @@ class TestFuseGemmEpilogueOpReluMTMFP64MultiDimX(
@skip_check_grad_ci(reason="no grap op") @skip_check_grad_ci(reason="no grap op")
@unittest.skipIf(not core.is_compiled_with_cuda(), @unittest.skipIf(not core.is_compiled_with_cuda(),
"core is not compiled with CUDA") "core is not compiled with CUDA")
class TestFuseGemmEpilogueOpGeluMMFP16(OpTest): class TestFuseGemmEpilogueOpGeluMMFP16(TestFuseGemmBase):
def setUp(self): def setUp(self):
self.op_type = "fused_gemm_epilogue" self.op_type = "fused_gemm_epilogue"
self.place = core.CUDAPlace(0) self.place = core.CUDAPlace(0)
...@@ -397,7 +402,7 @@ class TestFuseGemmEpilogueOpGeluMMFP64(TestFuseGemmEpilogueOpGeluMMFP16): ...@@ -397,7 +402,7 @@ class TestFuseGemmEpilogueOpGeluMMFP64(TestFuseGemmEpilogueOpGeluMMFP16):
@skip_check_grad_ci(reason="no grap op") @skip_check_grad_ci(reason="no grap op")
@unittest.skipIf(not core.is_compiled_with_cuda(), @unittest.skipIf(not core.is_compiled_with_cuda(),
"core is not compiled with CUDA") "core is not compiled with CUDA")
class TestFuseGemmEpilogueOpNoneMMFP16(OpTest): class TestFuseGemmEpilogueOpNoneMMFP16(TestFuseGemmBase):
def setUp(self): def setUp(self):
self.op_type = "fused_gemm_epilogue" self.op_type = "fused_gemm_epilogue"
self.place = core.CUDAPlace(0) self.place = core.CUDAPlace(0)
...@@ -446,5 +451,6 @@ class TestFuseGemmEpilogueOpNoneMMFP64(TestFuseGemmEpilogueOpNoneMMFP16): ...@@ -446,5 +451,6 @@ class TestFuseGemmEpilogueOpNoneMMFP64(TestFuseGemmEpilogueOpNoneMMFP16):
if __name__ == "__main__": if __name__ == "__main__":
paddle.enable_static()
np.random.seed(0)
unittest.main()
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
import paddle.fluid.core as core
import unittest
import numpy as np
from paddle.incubate.nn.functional import fused_matmul_bias, fused_linear
from paddle.incubate.nn import FusedLinear
def is_fused_matmul_bias_supported():
if paddle.is_compiled_with_cuda() and not paddle.is_compiled_with_rocm():
return hasattr(core.ops, 'fused_gemm_epilogue')
else:
return False
def matmul(x, y, bias, trans_x, trans_y):
x = np.array(x)
if trans_x:
x = np.ascontiguousarray(np.transpose(x))
if trans_y:
y = np.ascontiguousarray(np.transpose(y))
z = np.matmul(x, y)
if bias is None:
return z
else:
return z + bias
def matmul_grad(x, y, bias, dz, trans_x, trans_y):
if trans_x:
if trans_y:
dx = matmul(y, dz, None, True, True)
dy = matmul(dz, x, None, True, True)
else:
dx = matmul(y, dz, None, False, True)
dy = matmul(x, dz, None, False, False)
else:
if trans_y:
dx = matmul(dz, y, None, False, False)
dy = matmul(dz, x, None, True, False)
else:
dx = matmul(dz, y, None, False, True)
dy = matmul(x, dz, None, True, False)
if bias is None:
dbias = None
else:
dbias = np.sum(dz, axis=0, keepdims=False)
return dx, dy, dbias
@unittest.skipIf(
not is_fused_matmul_bias_supported(),
"fused_gemm_epilogue is only supported when CUDA version >= 11.6")
class TestFusedMatmulBias(unittest.TestCase):
def setUp(self):
paddle.set_device('gpu')
def rand_data(self, shape, dtype):
return np.random.randint(low=-20, high=20, size=shape).astype(dtype)
def rand_test_base(self, m, n, k, trans_x, trans_y, need_bias, dtype, seed):
np.random.seed(seed)
x_shape = [k, m] if trans_x else [m, k]
y_shape = [n, k] if trans_y else [k, n]
bias_shape = [n]
x_np = self.rand_data(x_shape, dtype)
x = paddle.to_tensor(x_np)
x.stop_gradient = False
y_np = self.rand_data(y_shape, dtype)
y = paddle.to_tensor(y_np)
y.stop_gradient = False
if need_bias:
bias_np = self.rand_data(bias_shape, dtype)
bias = paddle.to_tensor(bias_np)
bias.stop_gradient = False
else:
bias_np = None
bias = None
z = fused_matmul_bias(x, y, bias, trans_x, trans_y)
z_np = matmul(x_np, y_np, bias_np, trans_x, trans_y)
self.assertTrue(np.array_equal(z.numpy(), z_np))
z_grad_np = self.rand_data(z_np.shape, dtype)
paddle.autograd.backward(z, grad_tensors=[paddle.to_tensor(z_grad_np)])
x_grad_np, y_grad_np, bias_grad_np = matmul_grad(
x_np, y_np, bias_np, z_grad_np, trans_x, trans_y)
self.assertTrue(np.array_equal(x.grad.numpy(), x_grad_np))
self.assertEqual(y_grad_np.shape, y_np.shape)
self.assertTrue(np.array_equal(y.grad.numpy(), y_grad_np))
if need_bias:
self.assertTrue(np.array_equal(bias.grad.numpy(), bias_grad_np))
else:
self.assertTrue(bias_grad_np is None)
def rand_test(self, m, n, k, dtype):
seed = int(np.random.randint(low=0, high=1000, size=[1]))
for trans_x in [False, True]:
for trans_y in [False, True]:
for need_bias in [False, True]:
self.rand_test_base(m, n, k, trans_x, trans_y, need_bias,
dtype, seed)
def test_fp32(self):
self.rand_test(30, 40, 50, np.float32)
def test_fp16(self):
self.rand_test(4, 5, 7, np.float16)
@unittest.skipIf(
not is_fused_matmul_bias_supported(),
"fused_gemm_epilogue is only supported when CUDA version >= 11.6")
class TestFusedLinear(unittest.TestCase):
def check_fused_linear(self, transpose):
x = paddle.randn([30, 40])
linear = FusedLinear(40, 50, transpose_weight=transpose)
y1 = linear(x)
y2 = fused_linear(x, linear.weight, linear.bias, transpose)
self.assertTrue(np.array_equal(y1.numpy(), y2.numpy()))
def test_non_transpose(self):
self.check_fused_linear(False)
def test_transpose(self):
self.check_fused_linear(True)


@unittest.skipIf(
    not is_fused_matmul_bias_supported(),
    "fused_gemm_epilogue is only supported when CUDA version >= 11.6")
class TestStaticGraph(unittest.TestCase):
    def test_static_graph(self):
        paddle.enable_static()
        x = paddle.static.data(name='x', dtype='float32', shape=[-1, 100])
        linear = FusedLinear(100, 300)
        y = linear(x)
        self.assertEqual(list(y.shape), [-1, 300])
        paddle.disable_static()


if __name__ == "__main__":
    unittest.main()
...@@ -16,6 +16,7 @@ from .layer.fused_transformer import FusedMultiHeadAttention # noqa: F401 ...@@ -16,6 +16,7 @@ from .layer.fused_transformer import FusedMultiHeadAttention # noqa: F401
from .layer.fused_transformer import FusedFeedForward # noqa: F401 from .layer.fused_transformer import FusedFeedForward # noqa: F401
from .layer.fused_transformer import FusedTransformerEncoderLayer # noqa: F401 from .layer.fused_transformer import FusedTransformerEncoderLayer # noqa: F401
from .layer.fused_transformer import FusedMultiTransformer # noqa: F401 from .layer.fused_transformer import FusedMultiTransformer # noqa: F401
from .layer.fused_linear import FusedLinear # noqa: F401
from .layer.fused_transformer import FusedBiasDropoutResidualLayerNorm # noqa: F401 from .layer.fused_transformer import FusedBiasDropoutResidualLayerNorm # noqa: F401
__all__ = [ #noqa __all__ = [ #noqa
...@@ -23,5 +24,6 @@ __all__ = [ #noqa ...@@ -23,5 +24,6 @@ __all__ = [ #noqa
'FusedFeedForward', 'FusedFeedForward',
'FusedTransformerEncoderLayer', 'FusedTransformerEncoderLayer',
'FusedMultiTransformer', 'FusedMultiTransformer',
'FusedLinear',
'FusedBiasDropoutResidualLayerNorm', 'FusedBiasDropoutResidualLayerNorm',
] ]
...@@ -15,11 +15,14 @@ ...@@ -15,11 +15,14 @@
from .fused_transformer import fused_multi_head_attention from .fused_transformer import fused_multi_head_attention
from .fused_transformer import fused_feedforward from .fused_transformer import fused_feedforward
from .fused_transformer import fused_multi_transformer from .fused_transformer import fused_multi_transformer
from .fused_matmul_bias import fused_matmul_bias, fused_linear
from .fused_transformer import fused_bias_dropout_residual_layer_norm from .fused_transformer import fused_bias_dropout_residual_layer_norm
__all__ = [ __all__ = [
'fused_multi_head_attention', 'fused_multi_head_attention',
'fused_feedforward', 'fused_feedforward',
'fused_multi_transformer', 'fused_multi_transformer',
'fused_matmul_bias',
'fused_linear',
'fused_bias_dropout_residual_layer_norm', 'fused_bias_dropout_residual_layer_norm',
] ]
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from paddle.fluid.layer_helper import LayerHelper
from paddle.fluid.framework import _non_static_mode
from paddle.tensor.linalg import matmul
from paddle import _C_ops


def fused_matmul_bias(x,
                      y,
                      bias=None,
                      transpose_x=False,
                      transpose_y=False,
                      name=None):
    """
    Applies matrix multiplication of two tensors and then bias addition if provided.
    This method requires CUDA version >= 11.6.

    Args:
        x (Tensor): the first input Tensor to be multiplied.
        y (Tensor): the second input Tensor to be multiplied. Its rank must be 2.
        bias (Tensor|None): the input bias Tensor. If it is None, no bias addition would
            be performed. Otherwise, the bias is added to the matrix multiplication result.
        transpose_x (bool): Whether to transpose :math:`x` before multiplication.
        transpose_y (bool): Whether to transpose :math:`y` before multiplication.
        name (str|None): For detailed information, please refer to
            :ref:`api_guide_Name` . Usually there is no need to set it, and it is
            None by default.

    Returns:
        Tensor: the output Tensor.

    Examples:
        .. code-block:: python

            # required: gpu
            import paddle
            from paddle.incubate.nn.functional import fused_matmul_bias

            x = paddle.randn([3, 4])
            y = paddle.randn([4, 5])
            bias = paddle.randn([5])
            out = fused_matmul_bias(x, y, bias)
            print(out.shape) # [3, 5]
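
            # Additional illustrative snippet (y_t and out_t are placeholder
            # names): with transpose_y=True the second operand is laid out as
            # [5, 4] and transposed inside the fused kernel.
            y_t = paddle.randn([5, 4])
            out_t = fused_matmul_bias(x, y_t, bias, transpose_y=True)
            print(out_t.shape) # [3, 5]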
"""
if bias is None:
return matmul(x, y, transpose_x, transpose_y, name)
if _non_static_mode():
return _C_ops.fused_gemm_epilogue(x, y, bias, 'trans_x', transpose_x,
'trans_y', transpose_y)
helper = LayerHelper('fused_matmul_bias', **locals())
out = helper.create_variable_for_type_inference(dtype=x.dtype)
helper.append_op(
type='fused_gemm_epilogue',
inputs={'X': x,
'Y': y,
'Bias': bias},
outputs={'Out': out},
attrs={'trans_x': transpose_x,
'trans_y': transpose_y})
return out


def fused_linear(x, weight, bias=None, transpose_weight=False, name=None):
    """
    Fully-connected linear transformation operator. This method requires
    CUDA version >= 11.6.

    Args:
        x (Tensor): the input Tensor to be multiplied.
        weight (Tensor): the weight Tensor to be multiplied. Its rank must be 2.
        bias (Tensor|None): the input bias Tensor. If it is None, no bias addition would
            be performed. Otherwise, the bias is added to the matrix multiplication result.
        transpose_weight (bool): Whether to transpose :math:`weight` before multiplication.
        name (str|None): For detailed information, please refer to
            :ref:`api_guide_Name` . Usually there is no need to set it, and it is
            None by default.

    Returns:
        Tensor: the output Tensor.

    Examples:
        .. code-block:: python

            # required: gpu
            import paddle
            from paddle.incubate.nn.functional import fused_linear

            x = paddle.randn([3, 4])
            weight = paddle.randn([4, 5])
            bias = paddle.randn([5])
            out = fused_linear(x, weight, bias)
            print(out.shape) # [3, 5]
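
            # Additional illustrative snippet (weight_t and out_t are
            # placeholder names): with transpose_weight=True the weight is
            # laid out as [out_features, in_features].
            weight_t = paddle.randn([5, 4])
            out_t = fused_linear(x, weight_t, bias, transpose_weight=True)
            print(out_t.shape) # [3, 5]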
"""
return fused_matmul_bias(x, weight, bias, False, transpose_weight, name)
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from paddle.nn import Layer
from paddle.incubate.nn import functional as F


class FusedLinear(Layer):
    """
    Linear layer takes only one multi-dimensional tensor as input with the
    shape :math:`[batch\_size, *, in\_features]` , where :math:`*` means any
    number of additional dimensions. It multiplies the input tensor with the
    weight (a 2-D tensor of shape :math:`[in\_features, out\_features]` ) and
    produces an output tensor of shape :math:`[batch\_size, *, out\_features]` .
    If :math:`bias\_attr` is not False, the bias (a 1-D tensor of
    shape :math:`[out\_features]` ) will be created and added to the output.

    Parameters:
        in_features (int): The number of input units.
        out_features (int): The number of output units.
        weight_attr (ParamAttr, optional): The attribute for the learnable
            weight of this layer. The default value is None and the weight will be
            initialized to zero. For detailed information, please refer to
            paddle.ParamAttr.
        transpose_weight (bool): Whether to transpose the `weight` Tensor before
            multiplication.
        bias_attr (ParamAttr|bool, optional): The attribute for the learnable bias
            of this layer. If it is set to False, no bias will be added to the output.
            If it is set to None or one kind of ParamAttr, a bias parameter will
            be created according to ParamAttr. For detailed information, please refer
            to paddle.ParamAttr. The default value is None and the bias will be
            initialized to zero.
        name (str, optional): Normally there is no need for the user to set this
            parameter. For detailed information, please refer to :ref:`api_guide_Name` .

    Attribute:
        **weight** (Parameter): the learnable weight of this layer.

        **bias** (Parameter): the learnable bias of this layer.

    Shape:
        - input: Multi-dimensional tensor with shape :math:`[batch\_size, *, in\_features]` .
        - output: Multi-dimensional tensor with shape :math:`[batch\_size, *, out\_features]` .

    Examples:
        .. code-block:: python

            # required: gpu
            import paddle
            from paddle.incubate.nn import FusedLinear

            x = paddle.randn([3, 4])
            linear = FusedLinear(4, 5)
            y = linear(x)
            print(y.shape) # [3, 5]
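
            # Additional illustrative snippet (linear_t and y_t are placeholder
            # names): with transpose_weight=True the layer stores its weight as
            # [out_features, in_features] but produces the same output shape.
            linear_t = FusedLinear(4, 5, transpose_weight=True)
            y_t = linear_t(x)
            print(y_t.shape) # [3, 5]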
"""
def __init__(self,
in_features,
out_features,
weight_attr=None,
bias_attr=None,
transpose_weight=False,
name=None):
super(FusedLinear, self).__init__()
if transpose_weight:
weight_shape = [out_features, in_features]
else:
weight_shape = [in_features, out_features]
dtype = self._helper.get_default_dtype()
self.weight = self.create_parameter(
shape=weight_shape, attr=weight_attr, dtype=dtype, is_bias=False)
self.bias = self.create_parameter(
shape=[out_features], attr=bias_attr, dtype=dtype, is_bias=True)
self.transpose_weight = transpose_weight
self.name = name
def forward(self, input):
return F.fused_linear(input, self.weight, self.bias,
self.transpose_weight, self.name)