未验证 提交 2a3d9eca 编写于 作者: M Ming-Xu Huang 提交者: GitHub

cuBlasLt Epilogue To Fuse Linear + ReLU|GeLU (#39437)

* Added cuBlasLtHandle_t to device context.

* Added fused_gemm_epilogue op.

1. Added fused_gemm_epilogue op to leverage cuBlastLt Epilogue.
2. Support fusion Act(X*Y + bias), X'dims >=2 and Y'dims shoule be 2.
2. Act currently only be supported ReLU. (Will add GeLU in the future).

* Added UT to fused_gemm_epilogue op.

* Added LinearAct Pattern

1. Added LinearAct into graph_pattern_detector.* to define (2.)'s
pattern.
2. LinearAct is used to detect act(element_add(matmul_v2(x, w), bias)).
3. act currently only support ReLU (Will support GeLU in the future).

* Added FuseGemmEpiloguePass

1, Added FuseGemmEpiloguePass to handle nn.Linear + Act{ReLU}
fusion (GeLU will be supported in the future).
2. Only support matmul_v2 from nn.Linear.

* Added pybind to BuildStrageter.fuse_gemm_epilogue_.

* Added UT for fuse_gemm_epilogue_pass.

* GeLU support and EpilogueSingleton

1. Added GeLU support to fused_gemm_epilogue op.
2. Added EpilogueSingleton to cache auxiliary pointer.
3. Added related UTs.

* Rename cublaslt_epilogue_opto gemm_epilogue_op.*.

* Added both train and infer pattern to LinearAct.

1. Added support of fwd graph with grap_ops linking to LinearAct.
2. Added related changes to fuse_gemm_epilogue_pass for above
modification.

* Changed CUDA requirement from 11.4 to 11.6 for fuse_gemm_epilogue_pass.

* Added identity activation support to gemm_epilogue_op.

* Added Linear Fusion (matmul_v2 + ele_add)

1. Added matmul_v2 + ele_add pattern to LinearActPattern.
2. Added matmul_v2 + ele_add support to fuse_gemm_epilogue_pass.

* Rename gemm_epilogue_op.* to fused_gemm_epilogue_op.*

* Add fused_gemm_epilogue_grad op.

1. Added fused_gemm_epilogue_grad to support backward epilogue fusion.

* Add UTs to fused_gemm_epilogue_grad_op.

* Change attribute name in fused_gemm_epilogue_grad_op for clearing.

* Allow DX and DBias be dispensable to fused_gemm_epilogue_grad op.

* Added ElementwiseAdd+Matmul+Act graph pattern detection.

* Fuse backward of Linear( Act(x))

1. Added backward fusion pass to Linear( Act(x)).
2. Added backward fusion pass to Linear(x).

* Added UTs to backward fusion of Linear(Act(x)).

* Complete document of arguments to fused_gemm_epilogue_op.

* Made arguments of some functions pass by reference.

* Modify code with review comments.

1. Made arguments of some function pass by reference.
2. Removed redundant code.
3. Followed Google code style to change code.

* Made 'const' code style be consistent

* Fixed random seed of python UTs.

* Set Compiling constrains to cuBlasLt

1. Require CUDA 11.6+
2. Remove fuse_gemm_epilogue related tests when CUDA < 11.6.

* Code Reivew from Paddle

1. Changed arguments name is_first_gemm to without_x_gradient for
clearing.
2. Applied PADDLE_THROW in fused_gemm_epilogue_op.

* Remove EpilogueSingleton

1. Applied ReserveSpace to replace Epilogue for passing auxiliary
pointers between FWD and BWD.

* Fix a logical error and enhance UTs.

1. Added act op count checking in UTs.
2. Fix issue to fuse backward or ReLU(Linear(X)).
3. TODO: solve GELU fusion issues.

* Fix Linear and GeLU fusion issues.

1. Modified graph_detech_pattern to fit with both linear wiht gelu or
relu.
2. Modified data range in Uts to allow negative values.

* Removed fused_gemm_epilogue_op.h.

* Rename namespace pten to phi.

* Rename name of arguments in fused_gemm_epilogue_op

1. bias -> Bias.
2. out -> Out.
3. reserve_space -> ReserveSpace.

* Change EpiloguePassActivationCache as local variable.

1. Removed singleton in EpiloguePassActivationCache.
2. Made EpiloguePassActivationCache as an argument to each pass
functions.
上级 72964335
......@@ -293,11 +293,11 @@ function(op_library TARGET)
# Define operators that don't need pybind here.
foreach(manual_pybind_op "compare_all_op" "compare_op" "logical_op" "bitwise_op" "nccl_op"
"tensor_array_read_write_op" "tensorrt_engine_op" "conv_fusion_op")
if ("${TARGET}" STREQUAL "${manual_pybind_op}")
set(pybind_flag 1)
endif()
endforeach()
if ("${TARGET}" STREQUAL "${manual_pybind_op}")
set(pybind_flag 1)
endif()
endforeach()
# The registration of USE_OP, please refer to paddle/fluid/framework/op_registry.h.
# Note that it's enough to just adding one operator to pybind in a *_op.cc file.
......
......@@ -139,7 +139,7 @@ set(IR_PASS_DEPS graph_viz_pass multi_devices_graph_pass
coalesce_grad_tensor_pass fuse_all_reduce_op_pass backward_optimizer_op_deps_pass
fuse_adam_op_pass fuse_sgd_op_pass fuse_momentum_op_pass
sync_batch_norm_pass runtime_context_cache_pass graph_to_program_pass
fix_op_run_order_pass)
fix_op_run_order_pass fuse_gemm_epilogue_pass)
if (WITH_CINN)
set(IR_PASS_DEPS ${IR_PASS_DEPS} build_cinn_pass)
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Copyright (c) 2022 NVIDIA Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
......@@ -175,6 +176,11 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder {
!defined(_WIN32) && !defined(__APPLE__)
AppendPassWithCheck(strategy_.enable_auto_fusion_, "fusion_group_pass");
#endif
#if (defined(PADDLE_WITH_CUDA) && CUDA_VERSION >= 11060)
AppendPassWithCheck(strategy_.fuse_gemm_epilogue_,
"fuse_gemm_epilogue_pass");
#endif
AppendPassWithCheck(strategy_.fuse_elewise_add_act_ops_,
"fuse_elewise_add_act_pass");
// for single card training, fuse_all_reduce_ops is unnecessary.
......@@ -507,3 +513,6 @@ USE_PASS(mkldnn_placement_pass);
!defined(_WIN32) && !defined(__APPLE__)
USE_PASS(fusion_group_pass);
#endif
#if (defined(PADDLE_WITH_CUDA) && CUDA_VERSION >= 11060)
USE_PASS(fuse_gemm_epilogue_pass);
#endif
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
// Copyright (c) 2022 NVIDIA Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
......@@ -124,6 +125,8 @@ struct BuildStrategy {
paddle::optional<bool> fuse_broadcast_ops_{paddle::none};
// replace batch_norm with sync_batch_norm.
bool sync_batch_norm_{false};
// Fuse GEMM+Epilogue via cublasLt epilogue.
bool fuse_gemm_epilogue_{false};
// mkldnn_enabled_op_types specify the operator type list to
// use MKLDNN acceleration. It is null in default, means
......
......@@ -157,6 +157,7 @@ endif()
cc_library(fuse_bn_act_pass SRCS fuse_bn_act_pass.cc DEPS pass graph_pattern_detector )
cc_library(fuse_bn_add_act_pass SRCS fuse_bn_add_act_pass.cc DEPS pass graph_pattern_detector )
cc_library(fuse_elewise_add_act_pass SRCS fuse_elewise_add_act_pass.cc DEPS pass graph_pattern_detector )
cc_library(fuse_gemm_epilogue_pass SRCS fuse_gemm_epilogue_pass.cc DEPS pass graph_pattern_detector )
cc_library(fuse_relu_depthwise_conv_pass SRCS fuse_relu_depthwise_conv_pass.cc DEPS pass graph_pattern_detector )
set(GLOB_PASS_LIB ${PASS_LIBRARY} CACHE INTERNAL "Global PASS library")
......
此差异已折叠。
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
// Copyright (c) 2022 NVIDIA Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <mutex>
#include <string>
#include <unordered_set>
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/ir/pass.h"
namespace paddle {
namespace framework {
namespace ir {
/*
* Fuse the ElewiseAdd and activation
*/
class Graph;
class Node;
class EpiloguePassActivationCache {
public:
EpiloguePassActivationCache() {}
EpiloguePassActivationCache(const EpiloguePassActivationCache &) = delete;
void operator=(const EpiloguePassActivationCache &) = delete;
bool HasFusedActivation(const std::string &key) const {
return fused_activation_space_map_.count(key);
}
ir::Node *GetFusedActivationSpace(const std::string &key) {
if (HasFusedActivation(key)) {
return fused_activation_space_map_.find(key)->second;
}
PADDLE_THROW(platform::errors::InvalidArgument(
"The key (%d) of EpiloguePassActivationCache does not exist.", key));
}
void InsertFusedActivation(const std::string &key, ir::Node *const value) {
if (!HasFusedActivation(key)) {
mtx.lock();
fused_activation_space_map_.insert({key, value});
mtx.unlock();
} else {
PADDLE_THROW(platform::errors::AlreadyExists(
"The key (%d) of EpiloguePassActivationCache already exist.", key));
}
}
private:
std::unordered_map<std::string, ir::Node *> fused_activation_space_map_;
std::mutex mtx;
};
class FuseGemmEpiloguePass : public FusePassBase {
public:
virtual ~FuseGemmEpiloguePass() {}
protected:
void ApplyImpl(ir::Graph *graph) const override;
ir::Graph *FuseLinearFwd(ir::Graph *graph, bool is_training) const;
ir::Graph *FuseLinearActFwd(ir::Graph *graph,
const std::unordered_set<std::string> &act_types,
bool is_training, bool is_act_grad_x_from_act,
EpiloguePassActivationCache *cache) const;
ir::Graph *FuseLinearBwd(ir::Graph *graph, bool without_x_gradient) const;
ir::Graph *FuseLinearActBwd(
ir::Graph *graph, const std::unordered_set<std::string> &act_grad_types,
bool is_act_grad_x_from_act, EpiloguePassActivationCache *cache) const;
private:
bool IsGemmFromLinear_(const std::vector<int64_t> &x_shape,
const std::vector<int64_t> &w_shape,
OpDesc *matmul_v2_op) const;
const std::string GetReserveSpaceCacheKey(const std::string var_name,
int block_id) const {
return std::to_string(block_id) + var_name;
}
};
} // namespace ir
} // namespace framework
} // namespace paddle
......@@ -1461,31 +1461,6 @@ PDNode *patterns::BatchNormAddActGrad::operator()(
return bn_grad;
}
PDNode *patterns::ElewiseAddAct::operator()(
paddle::framework::ir::PDNode *ele_x_var,
std::unordered_set<std::string> act_types) {
auto *ele_y_var = pattern->NewNode(ele_y_repr())
->assert_is_op_input("elementwise_add", "Y");
auto *ele_add =
pattern->NewNode(ele_add_repr())->assert_is_op("elementwise_add");
auto *ele_out_var = pattern->NewNode(elewise_add_out_repr())
->assert_is_op_output("elementwise_add", "Out");
ele_out_var->AsIntermediate()->assert_is_ops_input(act_types);
auto *act = pattern->NewNode(act_repr())->assert_is_ops(act_types);
auto *act_out_var =
pattern->NewNode(act_out_repr())->assert_is_ops_output(act_types, "Out");
ele_add->LinksFrom({ele_x_var, ele_y_var}).LinksTo({ele_out_var});
act->LinksFrom({ele_out_var}).LinksTo({act_out_var});
return act_out_var;
}
PDNode *patterns::ElewiseAddActInplaceGrad::operator()(
paddle::framework::ir::PDNode *d_act_out_var,
std::unordered_set<std::string> act_types) {
......@@ -1526,6 +1501,159 @@ PDNode *patterns::ElewiseAddActInplaceGrad::operator()(
return ele_add_grad;
}
PDNode *patterns::ElewiseAddAct::operator()(
paddle::framework::ir::PDNode *ele_x_var,
std::unordered_set<std::string> act_types) {
auto *ele_y_var = pattern->NewNode(ele_y_repr())
->assert_is_op_input("elementwise_add", "Y");
auto *ele_add =
pattern->NewNode(ele_add_repr())->assert_is_op("elementwise_add");
auto *ele_out_var = pattern->NewNode(elewise_add_out_repr())
->assert_is_op_output("elementwise_add", "Out");
ele_out_var->AsIntermediate()->assert_is_ops_input(act_types);
auto *act = pattern->NewNode(act_repr())->assert_is_ops(act_types);
auto *act_out_var =
pattern->NewNode(act_out_repr())->assert_is_ops_output(act_types, "Out");
ele_add->LinksFrom({ele_x_var, ele_y_var}).LinksTo({ele_out_var});
act->LinksFrom({ele_out_var}).LinksTo({act_out_var});
return act_out_var;
}
PDNode *patterns::LinearAct::operator()(
paddle::framework::ir::PDNode *linear_x_var,
const std::unordered_set<std::string> &act_types, bool with_grad_link,
bool is_act_grad_x_from_act) {
auto *matmul_w_var =
pattern->NewNode(matmul_w_repr())->assert_is_op_input("matmul_v2", "Y");
auto *matmul = pattern->NewNode(matmul_repr())->assert_is_op("matmul_v2");
auto *matmul_out_var = pattern->NewNode(matmul_out_repr())
->assert_is_op_output("matmul_v2", "Out");
matmul_out_var->AsIntermediate()->assert_is_op_input("elementwise_add", "X");
auto *ele_bias_var = pattern->NewNode(ele_bias_repr())
->assert_is_op_input("elementwise_add", "Y");
auto *ele_add =
pattern->NewNode(ele_add_repr())->assert_is_op("elementwise_add");
auto *ele_out_var = pattern->NewNode(elewise_add_out_repr())
->assert_is_op_output("elementwise_add", "Out");
matmul->LinksFrom({linear_x_var, matmul_w_var}).LinksTo({matmul_out_var});
ele_add->LinksFrom({matmul_out_var, ele_bias_var}).LinksTo({ele_out_var});
if (with_grad_link) {
matmul_out_var->assert_is_op_input("elementwise_add_grad", "X");
auto *elementwise_add_grad_op = pattern->NewNode("elementwise_add_grad")
->assert_is_op("elementwise_add_grad");
elementwise_add_grad_op->LinksFrom({matmul_out_var});
}
if (act_types.size() > 0) {
ele_out_var->AsIntermediate()->assert_is_ops_input(act_types);
auto *act = pattern->NewNode(act_repr())->assert_is_ops(act_types);
auto *act_out_var = pattern->NewNode(act_out_repr())
->assert_is_ops_output(act_types, "Out");
act->LinksFrom({ele_out_var}).LinksTo({act_out_var});
if (with_grad_link && !is_act_grad_x_from_act) {
std::unordered_set<std::string> act_grad_types;
for (const auto &act : act_types) {
std::string act_grad(act);
act_grad.append("_grad");
act_grad_types.insert(act_grad);
}
ele_out_var->assert_is_ops_input(act_grad_types, "X");
auto *act_grad_op =
pattern->NewNode(act_grad_repr())->assert_is_ops(act_grad_types);
act_grad_op->LinksFrom({ele_out_var});
}
return act_out_var;
}
return ele_out_var;
}
PDNode *patterns::ElewiseAddMatmulAct::operator()(
paddle::framework::ir::PDNode *dout_var,
const std::unordered_set<std::string> &act_grad_types,
bool without_x_gradient, bool is_act_grad_x_from_act) {
auto *ele_grad_bias_var =
pattern->NewNode(ele_grad_bias_repr())
->assert_is_op_input("elementwise_add_grad", "Y");
auto *ele_add_grad = pattern->NewNode(ele_add_grad_repr())
->assert_is_op("elementwise_add_grad");
auto *ele_grad_dx_var =
pattern->NewNode(ele_grad_dx_repr())
->assert_is_op_output("elementwise_add_grad", GradVarName("X"));
auto *ele_grad_dbias_var =
pattern->NewNode(ele_grad_dbias_repr())
->assert_is_op_output("elementwise_add_grad", GradVarName("Y"));
ele_add_grad->LinksFrom({dout_var, ele_grad_bias_var})
.LinksTo({ele_grad_dx_var, ele_grad_dbias_var});
ele_grad_dx_var->AsIntermediate()->assert_is_op_input("matmul_v2_grad",
GradVarName("Out"));
auto *matmul_grad_x_var = pattern->NewNode(matmul_grad_x_repr())
->assert_is_op_input("matmul_v2_grad", "X");
auto *matmul_grad_w_var = pattern->NewNode(matmul_grad_w_repr())
->assert_is_op_input("matmul_v2_grad", "Y");
auto *matmul_grad =
pattern->NewNode(matmul_grad_repr())->assert_is_op("matmul_v2_grad");
auto *matmul_grad_dx_var =
pattern->NewNode(matmul_grad_dx_repr())
->assert_is_op_output("matmul_v2_grad", GradVarName("X"));
auto *matmul_grad_dw_var =
pattern->NewNode(matmul_grad_dw_repr())
->assert_is_op_output("matmul_v2_grad", GradVarName("Y"));
matmul_grad->LinksFrom(
{ele_grad_dx_var, matmul_grad_x_var, matmul_grad_w_var});
if (without_x_gradient) {
matmul_grad->LinksTo({matmul_grad_dw_var});
} else {
matmul_grad->LinksTo({matmul_grad_dx_var, matmul_grad_dw_var});
}
if (!without_x_gradient && act_grad_types.size() > 0) {
matmul_grad_dx_var->AsIntermediate()->assert_is_ops_input(
act_grad_types, GradVarName("Out"));
auto *act_grad =
pattern->NewNode(act_grad_repr())->assert_is_ops(act_grad_types);
auto *act_grad_dx_var =
pattern->NewNode(act_grad_dx_repr())
->assert_is_ops_output(act_grad_types, GradVarName("X"));
auto *act_grad_x_var = matmul_grad_x_var;
if (!is_act_grad_x_from_act) {
auto *ele_out_var = pattern->NewNode(ele_out_repr())
->assert_is_ops_input(act_grad_types, "X");
act_grad_x_var = ele_out_var;
}
act_grad->LinksFrom({matmul_grad_dx_var, act_grad_x_var})
.LinksTo({act_grad_dx_var});
return act_grad;
}
return matmul_grad;
}
// conv_type: conv2d, conv3d, conv2d_transpose
PDNode *patterns::ConvBias::operator()(
paddle::framework::ir::PDNode *conv_input, std::string conv_type) {
......
......@@ -863,6 +863,65 @@ struct ElewiseAddActInplaceGrad : public PatternBase {
PATTERN_DECL_NODE(ele_y);
};
// The following patterns are used to fuse linear and act (ReLu or GeLU)
// formula: act(F.linear(x))
// op: matmul_v2 + elementwise_add + act
// named nodes: matmul, elementwise_add, act
// matmul_w, matmul_out
// ele_bias, elewise_add_out, act_out
struct LinearAct : public PatternBase {
LinearAct(PDPattern* pattern, const std::string& name_scope)
: PatternBase(pattern, name_scope, "linear_act") {}
PDNode* operator()(PDNode* x,
const std::unordered_set<std::string>& act_types,
bool with_grad_link, bool is_act_grad_x_from_act);
// declare operator node's name
PATTERN_DECL_NODE(matmul);
PATTERN_DECL_NODE(ele_add);
PATTERN_DECL_NODE(act);
PATTERN_DECL_NODE(act_grad);
// declare variable node's name
PATTERN_DECL_NODE(matmul_w);
PATTERN_DECL_NODE(matmul_out);
PATTERN_DECL_NODE(elewise_add_out);
PATTERN_DECL_NODE(ele_bias);
PATTERN_DECL_NODE(act_out);
};
// The following patterns are used to fuse linear_grad and act_grad (ReLu or
// GeLU)
// formula: the backward of F.linear( act(x) )
// op: elementwise_add_grad + matmul_v2_grad + act_grad
// named nodes: ele_add_grad, matmul_grad, act_grad
// ele_grad_bias, ele_grad_dx, ele_grad_dbias
// matmul_grad_x, matmul_grad_dx, matmul_grad_dx
// matmul_grad_dw, act_grad_dx
struct ElewiseAddMatmulAct : public PatternBase {
ElewiseAddMatmulAct(PDPattern* pattern, const std::string& name_scope)
: PatternBase(pattern, name_scope, "elewiseadd_matmul_act") {}
PDNode* operator()(PDNode* x,
const std::unordered_set<std::string>& act_grad_types,
bool without_x_gradient, bool is_act_grad_x_from_act);
// declare operator node's name
PATTERN_DECL_NODE(ele_add_grad);
PATTERN_DECL_NODE(matmul_grad);
PATTERN_DECL_NODE(act_grad);
// declare variable node's name
PATTERN_DECL_NODE(ele_out);
PATTERN_DECL_NODE(ele_grad_bias);
PATTERN_DECL_NODE(ele_grad_dx);
PATTERN_DECL_NODE(ele_grad_dbias);
PATTERN_DECL_NODE(matmul_grad_x);
PATTERN_DECL_NODE(matmul_grad_w);
PATTERN_DECL_NODE(matmul_grad_dx);
PATTERN_DECL_NODE(matmul_grad_dw);
PATTERN_DECL_NODE(act_grad_dx);
};
// Conv with Elementwise_add as bias
// op: conv + elementwise_add
// named nodes:
......
......@@ -19,7 +19,8 @@ register_operators(EXCLUDES
fused_attention_op
fused_transformer_op
fused_feedforward_op
resnet_unit_op)
resnet_unit_op
fused_gemm_epilogue_op)
# fusion_gru_op does not have CUDA kernel
op_library(fusion_gru_op)
......@@ -79,4 +80,8 @@ if (WITH_GPU OR WITH_ROCM)
cc_test(test_cudnn_norm_conv SRCS cudnn_norm_conv_test.cc DEPS conv_op blas im2col vol2col depthwise_conv eigen_function tensor op_registry device_context generator memory)
cc_test(test_cudnn_bn_add_relu SRCS cudnn_bn_add_relu_test.cc DEPS batch_norm_op fused_bn_add_activation_op tensor op_registry device_context generator memory)
endif()
if (CUDA_VERSION GREATER_EQUAL 11.6)
op_library(fused_gemm_epilogue_op)
endif()
endif()
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Copyright (c) 2022 NVIDIA Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/op_version_registry.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
class FusedGemmEpilogueOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(framework::InferShapeContext* ctx) const override {
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "FusedGemmEpilogueOp");
OP_INOUT_CHECK(ctx->HasInput("Y"), "Input", "Y", "FusedGemmEpilogueOp");
OP_INOUT_CHECK(ctx->HasInput("Bias"), "Output", "Bias",
"FusedGemmEpilogueOp");
OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out",
"FusedGemmEpilogueOp");
auto x_dims = ctx->GetInputDim("X");
auto y_dims = ctx->GetInputDim("Y");
auto bias_dims = ctx->GetInputDim("Bias");
auto trans_x = ctx->Attrs().Get<bool>("trans_x");
auto trans_y = ctx->Attrs().Get<bool>("trans_y");
PADDLE_ENFORCE_EQ(
y_dims.size(), 2,
platform::errors::InvalidArgument(
"The Input tensor Y's dimension of FusedGemmEpilogueOp "
" should be 2, but got %d.",
y_dims.size()));
PADDLE_ENFORCE_GE(
x_dims.size(), 2,
platform::errors::InvalidArgument(
"The Input tensor X's dimension of FusedGemmEpilogueOp "
" should be >= 2, but got %d.",
x_dims.size()));
PADDLE_ENFORCE_EQ(
bias_dims.size(), 1,
platform::errors::InvalidArgument(
"The Input tensor bias's dimension of FusedGemmEpilogueOp "
" should be == 1, but got %d.",
bias_dims.size()));
PADDLE_ENFORCE_EQ(bias_dims[0], trans_y ? y_dims[0] : y_dims[1],
platform::errors::InvalidArgument(
"The Input tensor bias's dimension 0"
" should be == Y[-1], but got bias's shape = [%s] "
"and Y's shape = [%s]",
bias_dims, y_dims));
auto x_mat_dims =
phi::flatten_to_2d(x_dims, trans_x ? 1 : x_dims.size() - 1);
int K_from_x = trans_x ? x_mat_dims[0] : x_mat_dims[1];
int K_from_y = trans_y ? y_dims[1] : y_dims[0];
PADDLE_ENFORCE_EQ(
K_from_x, K_from_y,
platform::errors::InvalidArgument(
"The last dimension of X should be equal with Y's first dimension."
"But received X[-1] = [%d], Y[0] = [%d].",
K_from_x, K_from_y));
auto activation = ctx->Attrs().Get<std::string>("activation");
if ((activation != "relu") && (activation != "gelu") &&
(activation != "none")) {
PADDLE_ENFORCE_EQ(
true, false,
platform::errors::InvalidArgument(
"The activation attribute of fused_gemm_epilogue op should be"
" one of {\"none\", \"relu\", \"gelu\"}. But received %s."
"But received activation=%s.",
activation));
}
if (activation == "none" && ctx->HasOutput("ReserveSpace")) {
PADDLE_THROW(platform::errors::InvalidArgument(
"The ReserveSpace would not be used when activation = \"none\""));
}
// cublasLt's restriction for auxiliary.
if (ctx->HasOutput("ReserveSpace") && activation != "none") {
int min_size_of_n = activation == "relu" ? 128 : 8;
int N_size = trans_y ? y_dims[0] : y_dims[1];
PADDLE_ENFORCE_EQ(N_size % min_size_of_n, 0,
platform::errors::InvalidArgument(
"The output dimension N (X(MxK) * Y(KxN) = C(MxN)) "
"should be multiple of %d when auxiliary_key given "
"and activation=%s, but got N = %d.",
min_size_of_n, activation, N_size));
}
std::vector<int64_t> out_dims;
out_dims.reserve(static_cast<size_t>(x_dims.size()));
if (trans_x) {
for (int i = 1; i < x_dims.size(); ++i) out_dims.push_back(x_dims[i]);
} else {
for (int i = 0; i < x_dims.size() - 1; ++i) out_dims.push_back(x_dims[i]);
}
if (trans_y) {
out_dims.push_back(y_dims[0]);
} else {
out_dims.push_back(y_dims[1]);
}
ctx->SetOutputDim("Out", phi::make_ddim(out_dims));
// Note (Ming Huang): Reserve space of relu is a bit-mask,
// which cannot pass nan_and_inf checking if shape is set.
if (activation == "gelu" && ctx->HasOutput("ReserveSpace")) {
ctx->SetOutputDim("ReserveSpace", phi::make_ddim(out_dims));
}
}
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const {
framework::LibraryType library = framework::LibraryType::kPlain;
framework::DataLayout layout = framework::DataLayout::kAnyLayout;
auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");
return framework::OpKernelType(data_type, ctx.GetPlace(), layout, library);
}
};
class FusedGemmEpilogueOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X", "The input tensor X of Out = Act((X * Y) + Bias).");
AddInput("Y", "The input tensor Y of Out = Act((X * Y) + Bias).");
AddInput("Bias", "The input tensor bias of Out = Act((X * Y) + Bias).");
AddOutput("Out", "The output tensor Out of Out = Act((X * Y) + Bias).");
AddOutput("ReserveSpace",
R"DOC(Reserve GPU space to place
auxiliary data pointer. It is used to pass auxiliary data pointer
for fused_gemm_epilogue op. If not given (empty string), the
auxiliary mode would not be enable.)DOC")
.AsDispensable()
.AsExtra();
AddAttr<bool>(
"trans_x",
R"DOC((bool, default false), Whether to transpose input tensor X
or not. The input tensor X coulbe be more than two dimension. When
set trans_x=true, it would fully reverse X. For instant: X with shpae
[d0, d1, d2, d3] -> [d3, d2, d1, d0].)DOC")
.SetDefault(false);
AddAttr<bool>(
"trans_y",
R"DOC((bool, default false), Whether to transpose input tensor Y
or not. The input tensor Y should be two dimension. When
set trans_y=true, it would transpose Y. For instant: Y with shpae
[d0, d1] -> [d1, d0].)DOC")
.SetDefault(false);
AddAttr<std::string>(
"activation",
R"DOC((string, default none), The activation function. It could be
one of {none, relu, gelu}. When none is given, Act would be null
operations)DOC")
.SetDefault("none");
AddComment(R"DOC(
FusedGemmEpilogue Operator
This operator is used to perform Activeation(Elementwise_add(Matmul(X, Y), bias)).
It is equal to paddle.nn.Linear + Activation (None, ReLU or GeLU).
Note:
X could be more than two dimension and would be flatten to 2D for computing.
X with shape [d0, d1, d2, d3] -> X_2D with shape [d0*d1*d2, d3]
)DOC");
}
};
class FusedGemmEpilogueGradOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(framework::InferShapeContext* ctx) const override {
OP_INOUT_CHECK(ctx->HasInput("DOut"), "Input", "DOut",
"FusedGemmEpilogueGradOp");
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "FusedGemmEpilogueGradOp");
OP_INOUT_CHECK(ctx->HasInput("Y"), "Input", "Y", "FusedGemmEpilogueGradOp");
OP_INOUT_CHECK(ctx->HasOutput("DY"), "Output", "DY", "FusedGemmEpilogueOp");
auto dout_dims = ctx->GetInputDim("DOut");
auto x_dims = ctx->GetInputDim("X");
auto y_dims = ctx->GetInputDim("Y");
PADDLE_ENFORCE_GE(
dout_dims.size(), 2,
platform::errors::InvalidArgument(
"The Input tensor DOut's dimension of FusedGemmEpilogueGradOp "
" should be >= 2, but got %d.",
dout_dims.size()));
PADDLE_ENFORCE_EQ(
y_dims.size(), 2,
platform::errors::InvalidArgument(
"The Input tensor Y's dimension of FusedGemmEpilogueGradOp "
" should be 2, but got %d.",
y_dims.size()));
PADDLE_ENFORCE_GE(
x_dims.size(), 2,
platform::errors::InvalidArgument(
"The Input tensor X's dimension of FusedGemmEpilogueGradOp "
" should be >= 2, but got %d.",
x_dims.size()));
PADDLE_ENFORCE_EQ(
dout_dims.size(), x_dims.size(),
platform::errors::InvalidArgument(
"The Input tensor DOut's and X's dimension of "
"FusedGemmEpilogueGradOp "
" should be the same, but got DOut's dim = %d and X's = %d.",
dout_dims.size(), x_dims.size()));
auto dout_mat_dims = phi::flatten_to_2d(dout_dims, dout_dims.size() - 1);
auto x_mat_dims = phi::flatten_to_2d(x_dims, x_dims.size() - 1);
PADDLE_ENFORCE_EQ(
dout_mat_dims[1], y_dims[1],
platform::errors::InvalidArgument(
"The last dimension of DOut should be equal with Y's last"
"dimension. But received DOut[-1] = [%d], Y[1] = [%d].",
dout_mat_dims[1], y_dims[1]));
PADDLE_ENFORCE_EQ(
dout_mat_dims[0], x_mat_dims[0],
platform::errors::InvalidArgument(
"The first dimension of DOut should be equal with X's first"
"dimension. But received DOut[0] = [%d], Y[0] = [%d].",
dout_mat_dims[0], x_mat_dims[0]));
auto activation_grad = ctx->Attrs().Get<std::string>("activation_grad");
if ((activation_grad != "relu_grad") && (activation_grad != "gelu_grad") &&
(activation_grad != "none")) {
PADDLE_ENFORCE_EQ(
true, false,
platform::errors::InvalidArgument(
"The activation attribute of fused_gemm_epilogue op should be"
" one of {\"none\", \"relu\", \"gelu\"}. But received %s."
"But received activation=%s.",
activation_grad));
}
if (activation_grad != "none" && !ctx->HasInput("ReserveSpace")) {
PADDLE_ENFORCE_EQ(true, false,
platform::errors::InvalidArgument(
"The ReserveSpace should not be empty. "
"when activation_grad == {relu_grad, gelu_grad}."));
}
if (ctx->HasOutput("DX")) {
std::vector<int64_t> dx_dims;
dx_dims.reserve(static_cast<size_t>(x_dims.size()));
for (int i = 0; i < x_dims.size(); ++i) {
dx_dims.push_back(x_dims[i]);
}
ctx->SetOutputDim("DX", phi::make_ddim(dx_dims));
}
std::vector<int64_t> dy_dims(y_dims.Get(), y_dims.Get() + y_dims.size());
ctx->SetOutputDim("DY", phi::make_ddim(dy_dims));
if (ctx->HasOutput("DBias")) {
std::vector<int64_t> dbias_dims;
dbias_dims.push_back(y_dims[1]);
ctx->SetOutputDim("DBias", phi::make_ddim(dbias_dims));
}
}
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const {
framework::LibraryType library = framework::LibraryType::kPlain;
framework::DataLayout layout = framework::DataLayout::kAnyLayout;
auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "DOut");
return framework::OpKernelType(data_type, ctx.GetPlace(), layout, library);
}
};
class FusedGemmEpilogueGradOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("DOut",
"The input grad tensor to Out of Out = (Act(X) * Y) + bias");
AddInput("X", "The input tensor X of Out = (Act(X) * Y) + bias");
AddInput("Y", "The input tensor Y of Out = (Act(X) * Y) + bias");
AddInput("ReserveSpace",
R"DOC(A GPU space to fetch
auxiliary data pointer. It is used to pass auxiliary data pointer
for fused_gemm_epilogue_grad op. If not given (empty string), the
auxiliary mode would not be enable.)DOC")
.AsDispensable();
AddOutput("DX", "The output grad tensor to X of Out = (Act(X) * Y) + bias.")
.AsDispensable();
AddOutput("DY",
"The output grad tensor to Y of Out = (Act(X) * Y) + bias.");
AddOutput("DBias",
"The output grad tensor to bias of Out = (Act(X) * Y) + bias.")
.AsDispensable();
AddAttr<std::string>(
"activation_grad",
R"DOC((string, default none), The backward activation function. It could be
one of {none, relu_grad, gelu_grad}. When none is given, The backward Act would
be null operations)DOC")
.SetDefault("none");
AddComment(R"DOC(
FusedGemmEpilogueGrad Operator
This operator is used to perform backward of Elementwise_add(Matmul(Activeation(X), Y), bias).
It is equal to Activation (None, ReLU or GeLU) + paddle.nn.Linear.
Note:
X could be more than two dimension and would be flatten to 2D for computing.
X with shape [d0, d1, d2, d3] -> X_2D with shape [d0*d1*d2, d3]
)DOC");
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(fused_gemm_epilogue, ops::FusedGemmEpilogueOp,
ops::FusedGemmEpilogueOpMaker)
REGISTER_OPERATOR(fused_gemm_epilogue_grad, ops::FusedGemmEpilogueGradOp,
ops::FusedGemmEpilogueGradOpMaker)
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Copyright (c) 2022 NVIDIA Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/platform/dynload/cublasLt.h"
#include "paddle/fluid/platform/float16.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
template <typename DeviceContext, typename T>
class FusedGemmEpilogueKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto& dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
const Tensor* x = ctx.Input<Tensor>("X");
const Tensor* y = ctx.Input<Tensor>("Y");
const Tensor* bias = ctx.Input<Tensor>("Bias");
Tensor* out = ctx.Output<Tensor>("Out");
Tensor* reserve_space = ctx.Output<Tensor>("ReserveSpace");
bool trans_x = ctx.Attr<bool>("trans_x");
bool trans_y = ctx.Attr<bool>("trans_y");
std::string activation = ctx.Attr<std::string>("activation");
bool enable_auxiliary = reserve_space == nullptr ? false : true;
out->mutable_data<T>(ctx.GetPlace());
auto* out_data = out->data<T>();
auto x_mat_dims =
phi::flatten_to_2d(x->dims(), trans_x ? 1 : x->dims().size() - 1);
int64_t M = trans_x ? x_mat_dims[1] : x_mat_dims[0];
int64_t K = trans_y ? y->dims()[1] : y->dims()[0];
int64_t N = trans_y ? y->dims()[0] : y->dims()[1];
cudaDataType_t mat_type = CUDA_R_32F;
cudaDataType_t scale_type = CUDA_R_32F;
cublasComputeType_t compute_type = CUBLAS_COMPUTE_32F;
if (std::is_same<T, paddle::platform::float16>::value) {
mat_type = CUDA_R_16F;
scale_type = CUDA_R_16F;
}
if (std::is_same<T, double>::value) {
mat_type = CUDA_R_64F;
scale_type = CUDA_R_64F;
compute_type = CUBLAS_COMPUTE_64F;
}
cublasLtMatmulDesc_t operation_desc = NULL;
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasLtMatmulDescCreate(
&operation_desc, compute_type, scale_type));
cublasOperation_t transx = trans_x ? CUBLAS_OP_T : CUBLAS_OP_N;
cublasOperation_t transy = trans_y ? CUBLAS_OP_T : CUBLAS_OP_N;
PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::cublasLtMatmulDescSetAttribute(
operation_desc, CUBLASLT_MATMUL_DESC_TRANSB, &transx,
sizeof(transx)));
PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::cublasLtMatmulDescSetAttribute(
operation_desc, CUBLASLT_MATMUL_DESC_TRANSA, &transy,
sizeof(transy)));
cublasLtEpilogue_t epiloque_func =
get_epilogue_type_(activation, enable_auxiliary);
PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::cublasLtMatmulDescSetAttribute(
operation_desc, CUBLASLT_MATMUL_DESC_EPILOGUE, &epiloque_func,
sizeof(epiloque_func)));
const T* bias_data = bias->data<T>();
PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::cublasLtMatmulDescSetAttribute(
operation_desc, CUBLASLT_MATMUL_DESC_BIAS_POINTER, &bias_data,
sizeof(bias_data)));
if (enable_auxiliary && activation != "none") {
size_t reserve_space_size = 0;
if (activation == "relu") {
// Count in bits.
reserve_space_size = phi::product(out->dims()) / 8;
} else {
reserve_space_size = phi::product(out->dims()) * sizeof(T);
}
reserve_space->mutable_data(ctx.GetPlace(), out->type(),
reserve_space_size);
void* aux_data = reinterpret_cast<void*>(reserve_space->data<T>());
PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::cublasLtMatmulDescSetAttribute(
operation_desc, CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_POINTER,
&aux_data, sizeof(aux_data)));
PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::cublasLtMatmulDescSetAttribute(
operation_desc, CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_LD, &N,
sizeof(N)));
}
cublasLtMatrixLayout_t x_desc = NULL, y_desc = NULL, out_desc = NULL;
if (trans_x)
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasLtMatrixLayoutCreate(
&x_desc, mat_type, M, K, M));
else
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasLtMatrixLayoutCreate(
&x_desc, mat_type, K, M, K));
if (trans_y)
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasLtMatrixLayoutCreate(
&y_desc, mat_type, K, N, K));
else
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasLtMatrixLayoutCreate(
&y_desc, mat_type, N, K, N));
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasLtMatrixLayoutCreate(
&out_desc, mat_type, N, M, N));
cublasLtHandle_t lt_handle = dev_ctx.cublaslt_handle();
size_t workspace_size = 4 * 1024 * 1024;
const cublasLtMatmulAlgo_t* algo = nullptr;
cudaStream_t stream = dev_ctx.stream();
memory::allocation::AllocationPtr workspace =
memory::Alloc(dev_ctx, workspace_size);
double alpha64 = 1.0, beta64 = 0.0;
float alpha32 = 1.0f, beta32 = 0.0f;
void *alpha = nullptr, *beta = nullptr;
if (std::is_same<T, double>::value) {
alpha = &alpha64;
beta = &beta64;
} else {
alpha = &alpha32;
beta = &beta32;
}
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasLtMatmul(
lt_handle, operation_desc, alpha, y->data<T>(), y_desc, x->data<T>(),
x_desc, beta, out_data, out_desc, out_data, out_desc, algo,
workspace->ptr(), workspace_size, stream));
}
private:
static cublasLtEpilogue_t get_epilogue_type_(const std::string& activation,
bool enable_auxiliary) {
if (activation == "relu") {
return enable_auxiliary ? CUBLASLT_EPILOGUE_RELU_AUX_BIAS
: CUBLASLT_EPILOGUE_RELU_BIAS;
} else if (activation == "gelu") {
return enable_auxiliary ? CUBLASLT_EPILOGUE_GELU_AUX_BIAS
: CUBLASLT_EPILOGUE_GELU_BIAS;
} else if (activation == "none") {
return CUBLASLT_EPILOGUE_BIAS;
} else {
PADDLE_ENFORCE_EQ(
true, false,
platform::errors::InvalidArgument(
"The activation attribute of fused_gemm_epilogue op should be"
" one of {\"none\", \"relu\", \"gelu\"}. But received %s."
"But received activation=%s.",
activation));
}
}
};
template <typename DeviceContext, typename T>
class FusedGemmEpilogueGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto& dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
const Tensor* dout = ctx.Input<Tensor>("DOut");
const Tensor* x = ctx.Input<Tensor>("X");
const Tensor* y = ctx.Input<Tensor>("Y");
const Tensor* reserve_space = ctx.Input<Tensor>("ReserveSpace");
Tensor* dx = ctx.Output<Tensor>("DX");
Tensor* dy = ctx.Output<Tensor>("DY");
Tensor* dbias = ctx.Output<Tensor>("DBias");
std::string activation_grad = ctx.Attr<std::string>("activation_grad");
auto dout_mat_dims =
phi::flatten_to_2d(dout->dims(), dout->dims().size() - 1);
auto x_mat_dims = phi::flatten_to_2d(x->dims(), x->dims().size() - 1);
int64_t M = x_mat_dims[0];
int64_t K = y->dims()[0];
int64_t N = y->dims()[1];
cudaDataType_t mat_type = CUDA_R_32F;
cudaDataType_t scale_type = CUDA_R_32F;
cublasComputeType_t compute_type = CUBLAS_COMPUTE_32F;
if (std::is_same<T, paddle::platform::float16>::value) {
mat_type = CUDA_R_16F;
scale_type = CUDA_R_16F;
}
if (std::is_same<T, double>::value) {
mat_type = CUDA_R_64F;
scale_type = CUDA_R_64F;
compute_type = CUBLAS_COMPUTE_64F;
}
cublasLtHandle_t lt_handle = dev_ctx.cublaslt_handle();
size_t workspace_size = 4 * 1024 * 1024;
const cublasLtMatmulAlgo_t* algo = nullptr;
cudaStream_t stream = dev_ctx.stream();
double alpha64 = 1.0, beta64 = 0.0;
float alpha32 = 1.0f, beta32 = 0.0f;
void *alpha = nullptr, *beta = nullptr;
if (std::is_same<T, double>::value) {
alpha = &alpha64;
beta = &beta64;
} else {
alpha = &alpha32;
beta = &beta32;
}
cublasOperation_t trans_dout = CUBLAS_OP_N;
cublasLtMatrixLayout_t dout_desc = NULL;
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasLtMatrixLayoutCreate(
&dout_desc, mat_type, N, M, N));
if (dx) {
cublasLtMatmulDesc_t dx_operation_desc = NULL;
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasLtMatmulDescCreate(
&dx_operation_desc, compute_type, scale_type));
cublasOperation_t trans_y = CUBLAS_OP_T;
PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::cublasLtMatmulDescSetAttribute(
dx_operation_desc, CUBLASLT_MATMUL_DESC_TRANSB, &trans_dout,
sizeof(trans_dout)));
PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::cublasLtMatmulDescSetAttribute(
dx_operation_desc, CUBLASLT_MATMUL_DESC_TRANSA, &trans_y,
sizeof(trans_y)));
cublasLtEpilogue_t epiloque_func_for_dx =
get_epilogue_type_(activation_grad);
PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::cublasLtMatmulDescSetAttribute(
dx_operation_desc, CUBLASLT_MATMUL_DESC_EPILOGUE,
&epiloque_func_for_dx, sizeof(epiloque_func_for_dx)));
if (activation_grad != "none") {
auto* aux_data = reserve_space->data<T>();
PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::cublasLtMatmulDescSetAttribute(
dx_operation_desc, CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_POINTER,
&aux_data, sizeof(aux_data)));
PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::cublasLtMatmulDescSetAttribute(
dx_operation_desc, CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_LD, &N,
sizeof(N)));
}
cublasLtMatrixLayout_t y_desc = NULL, dx_desc = NULL;
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasLtMatrixLayoutCreate(
&y_desc, mat_type, N, K, N));
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasLtMatrixLayoutCreate(
&dx_desc, mat_type, K, M, K));
memory::allocation::AllocationPtr dx_workspace =
memory::Alloc(dev_ctx, workspace_size);
dx->mutable_data<T>(ctx.GetPlace());
auto* dx_data = dx->data<T>();
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasLtMatmul(
lt_handle, dx_operation_desc, alpha, y->data<T>(), y_desc,
dout->data<T>(), dout_desc, beta, dx_data, dx_desc, dx_data, dx_desc,
algo, dx_workspace->ptr(), workspace_size, stream));
}
if (dy) {
cublasLtMatmulDesc_t dy_operation_desc = NULL;
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasLtMatmulDescCreate(
&dy_operation_desc, compute_type, scale_type));
cublasOperation_t trans_x = CUBLAS_OP_T;
PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::cublasLtMatmulDescSetAttribute(
dy_operation_desc, CUBLASLT_MATMUL_DESC_TRANSA, &trans_dout,
sizeof(trans_dout)));
PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::cublasLtMatmulDescSetAttribute(
dy_operation_desc, CUBLASLT_MATMUL_DESC_TRANSB, &trans_x,
sizeof(trans_x)));
cublasLtEpilogue_t epiloque_func_for_dy = dbias == nullptr
? CUBLASLT_EPILOGUE_DEFAULT
: CUBLASLT_EPILOGUE_BGRADA;
PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::cublasLtMatmulDescSetAttribute(
dy_operation_desc, CUBLASLT_MATMUL_DESC_EPILOGUE,
&epiloque_func_for_dy, sizeof(epiloque_func_for_dy)));
if (dbias) {
dbias->mutable_data<T>(ctx.GetPlace());
auto* dbias_data = dbias->data<T>();
PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::cublasLtMatmulDescSetAttribute(
dy_operation_desc, CUBLASLT_MATMUL_DESC_BIAS_POINTER,
&dbias_data, sizeof(dbias_data)));
}
cublasLtMatrixLayout_t x_desc = NULL, dy_desc = NULL;
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasLtMatrixLayoutCreate(
&x_desc, mat_type, K, M, K));
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasLtMatrixLayoutCreate(
&dy_desc, mat_type, N, K, N));
memory::allocation::AllocationPtr dy_workspace =
memory::Alloc(dev_ctx, workspace_size);
dy->mutable_data<T>(ctx.GetPlace());
auto* dy_data = dy->data<T>();
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasLtMatmul(
lt_handle, dy_operation_desc, alpha, dout->data<T>(), dout_desc,
x->data<T>(), x_desc, beta, dy_data, dy_desc, dy_data, dy_desc, algo,
dy_workspace->ptr(), workspace_size, stream));
}
}
private:
static cublasLtEpilogue_t get_epilogue_type_(
const std::string& activation_grad) {
if (activation_grad == "relu_grad") {
return CUBLASLT_EPILOGUE_DRELU;
} else if (activation_grad == "gelu_grad") {
return CUBLASLT_EPILOGUE_DGELU;
} else if (activation_grad == "none") {
return CUBLASLT_EPILOGUE_DEFAULT;
} else {
PADDLE_ENFORCE_EQ(
true, false,
platform::errors::InvalidArgument(
"The activation_grad attribute of fused_gemm_epilogue op should "
"be"
" one of {\"none\", \"relu\", \"gelu\"}. But received %s."
"But received activation_grad=%s.",
activation_grad));
}
}
};
} // namespace operators
} // namespace paddle
#if CUDA_VERSION >= 11060
namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(
fused_gemm_epilogue,
ops::FusedGemmEpilogueKernel<paddle::platform::CUDADeviceContext, float>,
ops::FusedGemmEpilogueKernel<paddle::platform::CUDADeviceContext, double>,
ops::FusedGemmEpilogueKernel<paddle::platform::CUDADeviceContext,
paddle::platform::float16>);
REGISTER_OP_CUDA_KERNEL(
fused_gemm_epilogue_grad,
ops::FusedGemmEpilogueGradKernel<paddle::platform::CUDADeviceContext,
float>,
ops::FusedGemmEpilogueGradKernel<paddle::platform::CUDADeviceContext,
double>,
ops::FusedGemmEpilogueGradKernel<paddle::platform::CUDADeviceContext,
paddle::platform::float16>);
#endif
......@@ -19,6 +19,7 @@
#include "paddle/fluid/platform/device/gpu/gpu_types.h"
#include "paddle/fluid/platform/dynload/cublas.h"
#include "paddle/fluid/platform/dynload/cublasLt.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/macros.h"
......@@ -110,5 +111,28 @@ class CublasHandleHolder {
mutable std::mutex mtx_;
};
class CublasLtHandleHolder {
public:
CublasLtHandleHolder() {
PADDLE_RETRY_CUDA_SUCCESS(dynload::cublasLtCreate(&handle_));
}
const cublasLtHandle_t& GetCublasLtHandle() const { return handle_; }
~CublasLtHandleHolder() PADDLE_MAY_THROW {
PADDLE_RETRY_CUDA_SUCCESS(dynload::cublasLtDestroy(handle_));
}
inline void Call(const std::function<void(blasLtHandle_t)>& callback) const {
std::lock_guard<std::mutex> guard(mtx_);
callback(handle_);
}
private:
DISABLE_COPY_AND_ASSIGN(CublasLtHandleHolder);
cublasLtHandle_t handle_;
mutable std::mutex mtx_;
};
} // namespace platform
} // namespace paddle
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
// Copyright (c) 2022 NVIDIA Corporation. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
......@@ -24,6 +25,7 @@
#else
#include <cuda_runtime.h>
#include "paddle/fluid/platform/dynload/cublas.h"
#include "paddle/fluid/platform/dynload/cublasLt.h"
#include "paddle/fluid/platform/dynload/cudnn.h"
#endif
......@@ -70,6 +72,10 @@ DECLARE_TYPE_FOR_GPU(dnnHandle_t, cudnnHandle_t, miopenHandle_t);
DECLARE_TYPE_FOR_GPU(blasHandle_t, cublasHandle_t, rocblas_handle);
// TODO(Ming Huang): Since there is no blasLt handler,
// use rocblas_handle for workround.
DECLARE_TYPE_FOR_GPU(blasLtHandle_t, cublasLtHandle_t, rocblas_handle);
using CUDAGraphID = unsigned long long; // NOLINT
#undef DECLARE_TYPE_FOR_GPU
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Copyright (c) 2022 NVIDIA Corporation. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
......@@ -465,6 +467,9 @@ CUDAContext::CUDAContext(const CUDAPlace& place,
InitCuBlasContext();
InitCuDNNContext();
#ifndef PADDLE_WITH_HIP
#if CUDA_VERSION >= 11060
InitCuBlasLtContext();
#endif
InitCuSparseContext();
InitCuSolverContext();
#endif
......@@ -476,6 +481,9 @@ void CUDAContext::SetStream(gpuStream_t stream) {
DestoryCuDNNContext();
DestoryCuBlasContext();
#ifndef PADDLE_WITH_HIP
#if CUDA_VERSION >= 11060
DestoryCuBlasLtContext();
#endif
DestoryCuSolverContext();
#endif
......@@ -485,6 +493,9 @@ void CUDAContext::SetStream(gpuStream_t stream) {
InitCuBlasContext();
InitCuDNNContext();
#ifndef PADDLE_WITH_HIP
#if CUDA_VERSION >= 11060
InitCuBlasLtContext();
#endif
InitCuSolverContext();
#endif
}
......@@ -495,6 +506,9 @@ CUDAContext::~CUDAContext() {
DestoryCuDNNContext();
DestoryCuBlasContext();
#ifndef PADDLE_WITH_HIP
#if CUDA_VERSION >= 11060
InitCuBlasLtContext();
#endif
DestoryCuSparseContext();
DestoryCuSolverContext();
#endif
......@@ -551,6 +565,14 @@ cublasHandle_t CUDADeviceContext::cublas_handle() const {
}
return phi::GPUContext::cublas_handle();
}
#if CUDA_VERSION >= 11060
cublasLtHandle_t CUDADeviceContext::cublaslt_handle() const {
if (thread_ctx_.count(this)) {
return context()->CublasLtHandle()->GetCublasLtHandle();
}
return phi::GPUContext::cublaslt_handle();
}
#endif
cusparseHandle_t CUDADeviceContext::cusparse_handle() const {
if (thread_ctx_.count(this)) {
return context()->CusparseHandle()->GetCusparseHandle();
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Copyright (c) 2022 NVIDIA Corporation. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
......@@ -29,6 +31,7 @@ limitations under the License. */
#ifdef PADDLE_WITH_CUDA
#include "paddle/fluid/platform/device/gpu/gpu_helper.h"
#include "paddle/fluid/platform/dynload/cublas.h"
#include "paddle/fluid/platform/dynload/cublasLt.h"
#include "paddle/fluid/platform/dynload/cudnn.h"
#include "paddle/fluid/platform/dynload/cusolver.h"
#include "paddle/fluid/platform/dynload/cusparse.h"
......@@ -332,6 +335,12 @@ class CUDAContext {
}
#ifndef PADDLE_WITH_HIP
#if CUDA_VERSION >= 11060
const std::unique_ptr<CublasLtHandleHolder>& CublasLtHandle() const {
return cublaslt_handle_;
}
#endif
const std::unique_ptr<CusparseHandleHolder>& CusparseHandle() const {
return cusparse_handle_;
}
......@@ -348,6 +357,14 @@ class CUDAContext {
}
#ifndef PADDLE_WITH_HIP
#if CUDA_VERSION >= 11060
/*! \brief Call cublasLt function safely. */
inline void CublasLtCall(
const std::function<void(blasLtHandle_t)>& callback) const {
cublaslt_handle_->Call(callback);
}
#endif
/*! \brief Call cusparse function safely. */
inline void CusparseCall(
const std::function<void(phi::sparseHandle_t)>& callback) const {
......@@ -394,6 +411,12 @@ class CUDAContext {
#endif
#ifndef PADDLE_WITH_HIP
#if CUDA_VERSION >= 11060
void InitCuBlasLtContext() {
cublaslt_handle_.reset(new CublasLtHandleHolder());
}
#endif
void InitCuSparseContext() {
cusparse_handle_.reset(new CusparseHandleHolder(RawStream()));
}
......@@ -472,6 +495,10 @@ class CUDAContext {
}
#ifndef PADDLE_WITH_HIP
#if CUDA_VERSION >= 11060
void DestoryCuBlasLtContext() { cublaslt_handle_.reset(); }
#endif
void DestoryCuSparseContext() { cusparse_handle_.reset(); }
#endif
......@@ -497,6 +524,9 @@ class CUDAContext {
std::unique_ptr<CublasHandleHolder> cublas_tensor_core_handle_;
std::unique_ptr<CublasHandleHolder> cublas_tf32_tensor_core_handle_;
#ifndef PADDLE_WITH_HIP
#if CUDA_VERSION >= 11060
std::unique_ptr<CublasLtHandleHolder> cublaslt_handle_;
#endif
cusolverDnHandle_t cusolver_dn_handle_;
std::unique_ptr<CusparseHandleHolder> cusparse_handle_;
#endif
......@@ -559,6 +589,7 @@ class CUDADeviceContext : public phi::GPUContext {
rocblas_handle cublas_handle() const;
#else
cublasHandle_t cublas_handle() const;
cublasLtHandle_t cublaslt_handle() const;
cusparseHandle_t cusparse_handle() const;
#endif
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Copyright (c) 2022 NVIDIA Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
......@@ -3440,6 +3441,31 @@ All parameter, weight, gradient are variables in Paddle.
build_strategy = static.BuildStrategy()
build_strategy.fuse_elewise_add_act_ops = True
)DOC")
.def_property(
"fuse_gemm_epilogue",
[](const BuildStrategy &self) { return self.fuse_gemm_epilogue_; },
[](BuildStrategy &self, bool b) {
PADDLE_ENFORCE_NE(self.IsFinalized(), true,
platform::errors::PreconditionNotMet(
"BuildStrategy has been finlaized, cannot be "
"configured again."));
self.fuse_gemm_epilogue_ = b;
},
R"DOC((bool, optional): fuse_gemm_epilogue indicate whether
to fuse matmul_op, elemenewist_add_op and activation_op,
it may make the execution faster. Default is False.
Examples:
.. code-block:: python
import paddle
import paddle.static as static
paddle.enable_static()
build_strategy = static.BuildStrategy()
build_strategy.fuse_gemm_epilogue = True
)DOC")
.def_property(
"fuse_bn_act_ops",
[](const BuildStrategy &self) { return self.fuse_bn_act_ops_; },
......
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Copyright (c) 2022 NVIDIA Corporation. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
......@@ -56,6 +57,9 @@ using cudnnFusedOpsPlan_t = struct cudnnFusedOpsPlanStruct *;
// Forward declaration of cuBLAS types.
using cublasHandle_t = struct cublasContext *;
// Forward declaration of cuBLASLt types.
using cublasLtHandle_t = struct cublasLtContext *;
// Forward declaration of cuSOLVER types.
using cusolverDnHandle_t = struct cusolverDnContext *;
......
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Copyright (c) 2022 NVIDIA Corporation. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
......@@ -171,6 +172,7 @@ struct GPUContext::Impl {
InitStream();
InitEigenDevice();
InitBlasHandle();
InitBlasLtHandle();
InitDNNHandle();
InitSolverHandle();
InitSparseHandle();
......@@ -183,6 +185,7 @@ struct GPUContext::Impl {
InitGpuProperties();
InitStream();
InitBlasHandle();
InitBlasLtHandle();
InitDNNHandle();
InitSolverHandle();
InitSparseHandle();
......@@ -212,6 +215,7 @@ struct GPUContext::Impl {
}
#endif
DestroyInternalBlasHandle();
DestroyInternalBlasLtHandle();
DestoryInternalStream();
}
......@@ -418,6 +422,25 @@ struct GPUContext::Impl {
void SetBlasHandle(blasHandle_t blas) { blas_handle_ = blas; }
void InitBlasLtHandle() {
#if defined(PADDLE_WITH_CUDA) && CUDA_VERSION >= 11060
phi::dynload::cublasLtCreate(&blaslt_handle_);
#endif
}
void DestroyInternalBlasLtHandle() {
#if defined(PADDLE_WITH_CUDA) && CUDA_VERSION >= 11060
phi::dynload::cublasLtDestroy(blaslt_handle_);
#endif
}
void SetBlasLtHandle(blasLtHandle_t blaslt) { blaslt_handle_ = blaslt; }
blasLtHandle_t GetBlasLtHandle() const {
PD_CHECK(blaslt_handle_ != nullptr, "the gpu blasLt handle is nullptr.");
return blaslt_handle_;
}
void InitDNNHandle() {
if (phi::dynload::HasCUDNN()) {
#ifdef PADDLE_WITH_HIP
......@@ -679,6 +702,7 @@ struct GPUContext::Impl {
blasHandle_t blas_handle_{nullptr};
blasHandle_t blas_tensor_core_handle_{nullptr};
blasHandle_t blas_tf32_tensor_core_handle_{nullptr};
blasLtHandle_t blaslt_handle_{nullptr};
dnnHandle_t dnn_handle_{nullptr};
solverHandle_t solver_handle_{nullptr};
sparseHandle_t sparse_handle_{nullptr};
......@@ -725,6 +749,10 @@ blasHandle_t GPUContext::cublas_handle() const {
return impl_->GetBlasHandle();
}
blasLtHandle_t GPUContext::cublaslt_handle() const {
return impl_->GetBlasLtHandle();
}
solverHandle_t GPUContext::cusolver_dn_handle() const {
return impl_->GetSolverHandle();
}
......@@ -815,6 +843,10 @@ void GPUContext::SetBlasHandle(blasHandle_t blas) {
impl_->SetBlasHandle(blas);
}
void GPUContext::SetBlasLtHandle(blasLtHandle_t blaslt) {
impl_->SetBlasLtHandle(blaslt);
}
void GPUContext::SetDnnHandle(dnnHandle_t handle) {
impl_->SetDnnHandle(handle);
}
......
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Copyright (c) 2022 NVIDIA Corporation. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
......@@ -93,6 +94,9 @@ class GPUContext : public DeviceContext {
/*! \brief Return cublas handle in the device context. */
blasHandle_t cublas_handle() const;
/*! \brief Return cublasLt handle in the device context. */
blasLtHandle_t cublaslt_handle() const;
/*! \brief Return cusolver handle in the device context. */
solverHandle_t cusolver_dn_handle() const;
......@@ -193,6 +197,8 @@ class GPUContext : public DeviceContext {
void SetBlasHandle(blasHandle_t);
void SetBlasLtHandle(blasLtHandle_t);
void SetDnnHandle(dnnHandle_t);
void SetSolverHandle(solverHandle_t);
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
// Copyright (c) 2022 NVIDIA Corporation. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
......@@ -59,6 +60,10 @@ DECLARE_TYPE_FOR_GPU(dnnHandle_t, cudnnHandle_t, miopenHandle_t);
DECLARE_TYPE_FOR_GPU(blasHandle_t, cublasHandle_t, rocblas_handle);
// TODO(Ming Huang): Since there is no blasLt handler,
// use rocblas_handle for workround.
DECLARE_TYPE_FOR_GPU(blasLtHandle_t, cublasLtHandle_t, rocblas_handle);
DECLARE_TYPE_FOR_GPU(solverHandle_t, cusolverDnHandle_t, rocsolver_handle);
DECLARE_TYPE_FOR_GPU(sparseHandle_t, cusparseHandle_t, rocsparse_handle);
......
......@@ -125,6 +125,17 @@ if(NOT WITH_GPU)
LIST(REMOVE_ITEM TEST_OPS test_fused_attention_op)
LIST(REMOVE_ITEM TEST_OPS test_fused_attention_op_api)
LIST(REMOVE_ITEM TEST_OPS test_fused_transformer_encoder_layer)
LIST(REMOVE_ITEM TEST_OPS test_fused_gemm_epilogue_op)
LIST(REMOVE_ITEM TEST_OPS test_fused_gemm_epilogue_grad_op)
LIST(REMOVE_ITEM TEST_OPS test_fuse_gemm_epilogue_pass)
endif()
if (WITH_GPU)
if (CUDA_VERSION LESS 11.6)
LIST(REMOVE_ITEM TEST_OPS test_fused_gemm_epilogue_op)
LIST(REMOVE_ITEM TEST_OPS test_fused_gemm_epilogue_grad_op)
LIST(REMOVE_ITEM TEST_OPS test_fuse_gemm_epilogue_pass)
endif()
endif()
if(((NOT WITH_ROCM) AND (NOT WITH_GPU)) OR WIN32)
......
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
# Copyright (c) 2022 NVIDIA Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Test cases for role makers."""
from __future__ import print_function
import paddle
import os
import unittest
import numpy as np
import paddle.fluid.core as core
def compare(ref, res, atol, rtol):
ref = np.array(ref).flatten()
res = np.array(res).flatten()
tmp_ref = ref.astype(np.float)
tol = atol + rtol * abs(tmp_ref)
diff = abs(res - ref)
indices = np.transpose(np.where(diff > tol))
if len(indices) == 0:
return True
return False
def verify_node_count(graph, node_name, target_count):
count = 0
for node in graph.nodes():
if node.name() == node_name:
count += 1
return count == target_count
class MultiFCLayer(paddle.nn.Layer):
def __init__(self, hidden, Activation):
super(MultiFCLayer, self).__init__()
self.linear1 = paddle.nn.Linear(hidden, hidden)
self.linear2 = paddle.nn.Linear(hidden, hidden)
self.linear3 = paddle.nn.Linear(hidden, hidden)
self.relu1 = Activation()
self.relu2 = Activation()
self.relu3 = Activation()
def forward(self, x, matmul_y, ele_y):
output = self.linear1(x)
output = self.relu1(output)
output = self.linear2(output)
output1 = paddle.matmul(output, matmul_y)
output = self.linear3(output)
output = self.relu2(output)
output = paddle.matmul(output, matmul_y)
output = paddle.add(output, ele_y)
output = self.relu3(output)
output = paddle.add(output, output1)
return output
@unittest.skipIf(not core.is_compiled_with_cuda(),
"core is not compiled with CUDA")
class TestFuseGemmEpilogueFWDBase(unittest.TestCase):
def setUp(self):
self.batch = 64
self.seqlen = 128
self.hidden = 768
paddle.enable_static()
self.main_prog = paddle.static.Program()
self.startup_prog = paddle.static.Program()
with paddle.static.program_guard(self.main_prog, self.startup_prog):
data = paddle.static.data(
name="_data",
shape=[-1, self.seqlen, self.hidden],
dtype='float32')
matmul_y = paddle.static.data(
name="_matmul_y",
shape=[1, self.hidden, self.hidden],
dtype='float32')
ele_y = paddle.static.data(
name="_ele_y", shape=[self.hidden, ], dtype='float32')
multi_layer = MultiFCLayer(self.hidden, self._get_act_type()[0])
with paddle.static.amp.fp16_guard():
out = multi_layer(data, matmul_y, ele_y)
self.loss = paddle.mean(out)
self.data_arr = np.random.random(
(self.batch, self.seqlen, self.hidden)).astype("float32") - 0.5
self.matmul_y_arr = np.random.random(
(1, self.hidden, self.hidden)).astype("float32") - 0.5
self.ele_y_arr = np.random.random(
(self.hidden, )).astype("float32") - 0.5
self.place = paddle.CUDAPlace(0)
self.exe = paddle.static.Executor(self.place)
self.exe.run(self.startup_prog)
self._pre_test_hooks()
self.feed = {
"_data": self.data_arr,
"_matmul_y": self.matmul_y_arr,
"_ele_y": self.ele_y_arr
}
self.reference = self.exe.run(self.main_prog,
feed=self.feed,
fetch_list=[self.loss.name])
@unittest.skipIf(not core.is_compiled_with_cuda(),
"core is not compiled with CUDA")
def _test_output(self):
build_strategy = paddle.static.BuildStrategy()
build_strategy.fuse_gemm_epilogue = True
program = paddle.static.CompiledProgram(self.main_prog)
program = program.with_data_parallel(
loss_name=self.loss.name,
build_strategy=build_strategy,
places=paddle.static.cuda_places())
result = self.exe.run(program,
feed=self.feed,
fetch_list=[self.loss.name])
self.assertTrue(
compare(self.reference, result, self.atol, self.rtol),
"[{}] outputs are miss-matched.".format(type(self).__name__))
self.assertTrue(
verify_node_count(program._graph, "fused_gemm_epilogue", 3),
"[{}] The number of fused_gemm_epilogue is miss-matched in the computing graph.".
format(type(self).__name__))
act_fwd_name = self._get_act_type()[1]
self.assertTrue(
verify_node_count(program._graph, act_fwd_name, 1),
"[{}] The number of {} is miss-matched in the computing graph.".
format(type(self).__name__, act_fwd_name))
def _pre_test_hooks(self):
self.atol = 1e-4
self.rtol = 1e-3
def _get_act_type(self):
return paddle.nn.ReLU, "relu"
@unittest.skipIf(not core.is_compiled_with_cuda(),
"core is not compiled with CUDA")
class TestFuseGemmEpilogueReluFWDFP32(TestFuseGemmEpilogueFWDBase):
def _pre_test_hooks(self):
self.atol = 1e-3
self.rtol = 1e-2
def _get_act_type(self):
return paddle.nn.ReLU, "relu"
def test_output(self):
self._test_output()
@unittest.skipIf(not core.is_compiled_with_cuda(),
"core is not compiled with CUDA")
class TestFuseGemmEpilogueReluFWDFP16(TestFuseGemmEpilogueReluFWDFP32):
def _pre_test_hooks(self):
self.atol = 1e-3
self.rtol = 1e-2
fp16_var_list = paddle.static.amp.cast_model_to_fp16(self.main_prog)
paddle.static.amp.cast_parameters_to_fp16(
self.place, self.main_prog, to_fp16_var_names=fp16_var_list)
self.data_arr = self.data_arr.astype("float16")
self.matmul_y_arr = self.matmul_y_arr.astype("float16")
self.ele_y_arr = self.ele_y_arr.astype("float16")
@unittest.skipIf(not core.is_compiled_with_cuda(),
"core is not compiled with CUDA")
class TestFuseGemmEpilogueGeluFWDFP32(TestFuseGemmEpilogueFWDBase):
def _pre_test_hooks(self):
self.atol = 1e-4
self.rtol = 1e-3
def _get_act_type(self):
return paddle.nn.GELU, "gelu"
def test_output(self):
self._test_output()
@unittest.skipIf(not core.is_compiled_with_cuda(),
"core is not compiled with CUDA")
class TestFuseGemmEpilogueGeluFWDFP16(TestFuseGemmEpilogueGeluFWDFP32):
def _pre_test_hooks(self):
self.atol = 1e-3
self.rtol = 1e-2
fp16_var_list = paddle.static.amp.cast_model_to_fp16(self.main_prog)
paddle.static.amp.cast_parameters_to_fp16(
self.place, self.main_prog, to_fp16_var_names=fp16_var_list)
self.data_arr = self.data_arr.astype("float16")
self.matmul_y_arr = self.matmul_y_arr.astype("float16")
self.ele_y_arr = self.ele_y_arr.astype("float16")
@unittest.skipIf(not core.is_compiled_with_cuda(),
"core is not compiled with CUDA")
class TestFuseGemmEpilogueBWDBase(unittest.TestCase):
def setUp(self):
self.batch = 64
self.seqlen = 128
self.hidden = 768
paddle.enable_static()
self.main_prog = paddle.static.Program()
self.startup_prog = paddle.static.Program()
with paddle.static.program_guard(self.main_prog, self.startup_prog):
data = paddle.static.data(
name="_data",
shape=[-1, self.seqlen, self.hidden],
dtype='float32')
matmul_y = paddle.static.data(
name="_matmul_y",
shape=[1, self.hidden, self.hidden],
dtype='float32')
ele_y = paddle.static.data(
name="_ele_y", shape=[self.hidden, ], dtype='float32')
multi_layer = MultiFCLayer(self.hidden, self._get_act_type()[0])
with paddle.static.amp.fp16_guard():
out = multi_layer(data, matmul_y, ele_y)
self.loss = paddle.mean(out)
paddle.static.append_backward(loss=self.loss)
self.data_arr = np.random.random(
(self.batch, self.seqlen, self.hidden)).astype("float32") - 0.5
self.matmul_y_arr = np.random.random(
(1, self.hidden, self.hidden)).astype("float32") - 0.5
self.ele_y_arr = np.random.random(
(self.hidden, )).astype("float32") - 0.5
self.place = paddle.CUDAPlace(0)
self.exe = paddle.static.Executor(self.place)
self.exe.run(self.startup_prog)
self._pre_test_hooks()
self.feed = {
"_data": self.data_arr,
"_matmul_y": self.matmul_y_arr,
"_ele_y": self.ele_y_arr
}
self.fetch = [
self.loss.name,
'{}.w_0@GRAD'.format(multi_layer.linear1.full_name()),
'{}.b_0@GRAD'.format(multi_layer.linear1.full_name()),
'{}.w_0@GRAD'.format(multi_layer.linear2.full_name()),
'{}.b_0@GRAD'.format(multi_layer.linear2.full_name()),
'{}.w_0@GRAD'.format(multi_layer.linear3.full_name()),
'{}.b_0@GRAD'.format(multi_layer.linear3.full_name())
]
self.outs_ref = self.exe.run(self.main_prog,
feed=self.feed,
fetch_list=self.fetch)
@unittest.skipIf(not core.is_compiled_with_cuda(),
"core is not compiled with CUDA")
def _test_output(self):
build_strategy = paddle.static.BuildStrategy()
build_strategy.fuse_gemm_epilogue = True
program = paddle.static.CompiledProgram(self.main_prog)
program = program.with_data_parallel(
loss_name=self.loss.name,
build_strategy=build_strategy,
places=paddle.static.cuda_places())
outs_res = self.exe.run(program, feed=self.feed, fetch_list=self.fetch)
for ref, res in zip(self.outs_ref, outs_res):
self.assertTrue(
compare(ref, res, self.atol, self.rtol),
"[{}] output is miss-matched.".format(type(self).__name__))
self.assertTrue(
verify_node_count(program._graph, "fused_gemm_epilogue", 3),
"[{}] The number of fused_gemm_epilogue is miss-matched in the computing graph.".
format(type(self).__name__))
self.assertTrue(
verify_node_count(program._graph, "fused_gemm_epilogue_grad", 3),
"[{}] The number of fused_gemm_epilogue_grad is miss-matched in the computing graph.".
format(type(self).__name__))
_, act_fwd_name, act_bwd_name = self._get_act_type()
self.assertTrue(
verify_node_count(program._graph, act_fwd_name, 1),
"[{}] The number of {} is miss-matched in the computing graph.".
format(type(self).__name__, act_fwd_name))
self.assertTrue(
verify_node_count(program._graph, act_bwd_name, 2),
"[{}] The number of {} is miss-matched in the computing graph.".
format(type(self).__name__, act_bwd_name))
def _pre_test_hooks(self):
self.atol = 1e-4
self.rtol = 1e-3
def _get_act_type(self):
return paddle.nn.ReLU, "relu", "relu_grad"
@unittest.skipIf(not core.is_compiled_with_cuda(),
"core is not compiled with CUDA")
class TestFuseGemmEpilogueReLUBWDFP32(TestFuseGemmEpilogueBWDBase):
def _pre_test_hooks(self):
self.atol = 1e-4
self.rtol = 1e-3
def _get_act_type(self):
return paddle.nn.ReLU, "relu", "relu_grad"
def test_output(self):
self._test_output()
@unittest.skipIf(not core.is_compiled_with_cuda(),
"core is not compiled with CUDA")
class TestFuseGemmEpilogueReLUBWDFP16(TestFuseGemmEpilogueReLUBWDFP32):
def _pre_test_hooks(self):
self.atol = 1e-3
self.rtol = 1e-2
fp16_var_list = paddle.static.amp.cast_model_to_fp16(self.main_prog)
paddle.static.amp.cast_parameters_to_fp16(
self.place, self.main_prog, to_fp16_var_names=fp16_var_list)
self.data_arr = self.data_arr.astype("float16")
self.matmul_y_arr = self.matmul_y_arr.astype("float16")
self.ele_y_arr = self.ele_y_arr.astype("float16")
@unittest.skipIf(not core.is_compiled_with_cuda(),
"core is not compiled with CUDA")
class TestFuseGemmEpilogueGeLUBWDFP32(TestFuseGemmEpilogueBWDBase):
def _pre_test_hooks(self):
self.atol = 5e-4
self.rtol = 1e-3
def _get_act_type(self):
return paddle.nn.GELU, "gelu", "gelu_grad"
def test_output(self):
self._test_output()
@unittest.skipIf(not core.is_compiled_with_cuda(),
"core is not compiled with CUDA")
class TestFuseGemmEpilogueGeLUBWDFP16(TestFuseGemmEpilogueGeLUBWDFP32):
def _pre_test_hooks(self):
self.atol = 1e-3
self.rtol = 1e-2
fp16_var_list = paddle.static.amp.cast_model_to_fp16(self.main_prog)
paddle.static.amp.cast_parameters_to_fp16(
self.place, self.main_prog, to_fp16_var_names=fp16_var_list)
self.data_arr = self.data_arr.astype("float16")
self.matmul_y_arr = self.matmul_y_arr.astype("float16")
self.ele_y_arr = self.ele_y_arr.astype("float16")
if __name__ == "__main__":
np.random.seed(0)
unittest.main()
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
# Copyright (c) 2022 NVIDIA Corporation. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import numpy as np
import paddle
import paddle.fluid.core as core
from op_test import OpTest, skip_check_grad_ci
def get_outputs(DOut, X, Y):
DX = np.dot(DOut, Y.T)
DY = np.dot(X.T, DOut)
DBias = np.sum(DOut, axis=0)
return DX, DY, DBias
@skip_check_grad_ci(reason="no grap op")
@unittest.skipIf(not core.is_compiled_with_cuda(),
"core is not compiled with CUDA")
class TestFuseGemmEpilogueGradOpDXYBiasFP16(OpTest):
def setUp(self):
self.op_type = "fused_gemm_epilogue_grad"
self.place = core.CUDAPlace(0)
self.init_dtype_type()
self.inputs = {
'DOut': np.random.random((8, 128)).astype(self.dtype) - 0.5,
'X': np.random.random((8, 4)).astype(self.dtype) - 0.5,
'Y': np.random.random((4, 128)).astype(self.dtype) - 0.5
}
self.attrs = {"activation": 'none'}
DX, DY, DBias = get_outputs(self.inputs['DOut'], self.inputs['X'],
self.inputs['Y'])
self.outputs = {'DX': DX, 'DY': DY, 'DBias': DBias}
def init_dtype_type(self):
self.dtype = np.float16
self.atol = 1e-3
def test_check_output(self):
if self.dtype == np.float16 and not core.is_float16_supported(
self.place):
return
self.check_output_with_place(self.place, atol=self.atol)
@skip_check_grad_ci(reason="no grap op")
@unittest.skipIf(not core.is_compiled_with_cuda(),
"core is not compiled with CUDA")
class TestFuseGemmEpilogueGradOpDXYBiasFP32(
TestFuseGemmEpilogueGradOpDXYBiasFP16):
def init_dtype_type(self):
self.dtype = np.single
self.atol = 1e-6
@skip_check_grad_ci(reason="no grap op")
@unittest.skipIf(not core.is_compiled_with_cuda(),
"core is not compiled with CUDA")
class TestFuseGemmEpilogueGradOpDXYBiasFP64(
TestFuseGemmEpilogueGradOpDXYBiasFP16):
def init_dtype_type(self):
self.dtype = np.double
self.atol = 1e-6
@skip_check_grad_ci(reason="no grap op")
@unittest.skipIf(not core.is_compiled_with_cuda(),
"core is not compiled with CUDA")
class TestFuseGemmEpilogueGradOpDYBiasFP16(OpTest):
def setUp(self):
self.op_type = "fused_gemm_epilogue_grad"
self.place = core.CUDAPlace(0)
self.init_dtype_type()
self.inputs = {
'DOut': np.random.random((8, 128)).astype(self.dtype) - 0.5,
'X': np.random.random((8, 4)).astype(self.dtype) - 0.5,
'Y': np.random.random((4, 128)).astype(self.dtype) - 0.5
}
self.attrs = {"activation": 'none'}
_, DY, DBias = get_outputs(self.inputs['DOut'], self.inputs['X'],
self.inputs['Y'])
self.outputs = {'DY': DY, 'DBias': DBias}
def init_dtype_type(self):
self.dtype = np.float16
self.atol = 1e-3
def test_check_output(self):
if self.dtype == np.float16 and not core.is_float16_supported(
self.place):
return
self.check_output_with_place(self.place, atol=self.atol)
@skip_check_grad_ci(reason="no grap op")
@unittest.skipIf(not core.is_compiled_with_cuda(),
"core is not compiled with CUDA")
class TestFuseGemmEpilogueGradOpDYBiasFP32(
TestFuseGemmEpilogueGradOpDYBiasFP16):
def init_dtype_type(self):
self.dtype = np.single
self.atol = 1e-6
@skip_check_grad_ci(reason="no grap op")
@unittest.skipIf(not core.is_compiled_with_cuda(),
"core is not compiled with CUDA")
class TestFuseGemmEpilogueGradOpDYBiasFP64(
TestFuseGemmEpilogueGradOpDYBiasFP16):
def init_dtype_type(self):
self.dtype = np.double
self.atol = 1e-6
@skip_check_grad_ci(reason="no grap op")
@unittest.skipIf(not core.is_compiled_with_cuda(),
"core is not compiled with CUDA")
class TestFuseGemmEpilogueGradOpDYFP16(OpTest):
def setUp(self):
self.op_type = "fused_gemm_epilogue_grad"
self.place = core.CUDAPlace(0)
self.init_dtype_type()
self.inputs = {
'DOut': np.random.random((8, 128)).astype(self.dtype) - 0.5,
'X': np.random.random((8, 4)).astype(self.dtype) - 0.5,
'Y': np.random.random((4, 128)).astype(self.dtype) - 0.5
}
self.attrs = {"activation": 'none'}
_, DY, _ = get_outputs(self.inputs['DOut'], self.inputs['X'],
self.inputs['Y'])
self.outputs = {'DY': DY}
def init_dtype_type(self):
self.dtype = np.float16
self.atol = 1e-3
def test_check_output(self):
if self.dtype == np.float16 and not core.is_float16_supported(
self.place):
return
self.check_output_with_place(self.place, atol=self.atol)
@skip_check_grad_ci(reason="no grap op")
@unittest.skipIf(not core.is_compiled_with_cuda(),
"core is not compiled with CUDA")
class TestFuseGemmEpilogueGradOpDYFP32(TestFuseGemmEpilogueGradOpDYFP16):
def init_dtype_type(self):
self.dtype = np.single
self.atol = 1e-6
@skip_check_grad_ci(reason="no grap op")
@unittest.skipIf(not core.is_compiled_with_cuda(),
"core is not compiled with CUDA")
class TestFuseGemmEpilogueGradOpDYFP64(TestFuseGemmEpilogueGradOpDYFP16):
def init_dtype_type(self):
self.dtype = np.double
self.atol = 1e-6
@skip_check_grad_ci(reason="no grap op")
@unittest.skipIf(not core.is_compiled_with_cuda(),
"core is not compiled with CUDA")
class TestFuseGemmEpilogueGradOpDXYFP16(OpTest):
def setUp(self):
self.op_type = "fused_gemm_epilogue_grad"
self.place = core.CUDAPlace(0)
self.init_dtype_type()
self.inputs = {
'DOut': np.random.random((8, 128)).astype(self.dtype) - 0.5,
'X': np.random.random((8, 4)).astype(self.dtype) - 0.5,
'Y': np.random.random((4, 128)).astype(self.dtype) - 0.5
}
self.attrs = {"activation": 'none'}
DX, DY, _ = get_outputs(self.inputs['DOut'], self.inputs['X'],
self.inputs['Y'])
self.outputs = {'DX': DX, 'DY': DY}
def init_dtype_type(self):
self.dtype = np.float16
self.atol = 1e-3
def test_check_output(self):
if self.dtype == np.float16 and not core.is_float16_supported(
self.place):
return
self.check_output_with_place(self.place, atol=self.atol)
@skip_check_grad_ci(reason="no grap op")
@unittest.skipIf(not core.is_compiled_with_cuda(),
"core is not compiled with CUDA")
class TestFuseGemmEpilogueGradOpDXYFP32(TestFuseGemmEpilogueGradOpDXYFP16):
def init_dtype_type(self):
self.dtype = np.single
self.atol = 1e-6
@skip_check_grad_ci(reason="no grap op")
@unittest.skipIf(not core.is_compiled_with_cuda(),
"core is not compiled with CUDA")
class TestFuseGemmEpilogueGradOpDXYFP64(TestFuseGemmEpilogueGradOpDXYFP16):
def init_dtype_type(self):
self.dtype = np.double
self.atol = 1e-6
if __name__ == "__main__":
np.random.seed(0)
unittest.main()
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
# Copyright (c) 2022 NVIDIA Corporation. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import numpy as np
import paddle
import paddle.fluid.core as core
from op_test import OpTest, skip_check_grad_ci
def gelu(x):
y_ref = 0.5 * x * (
1.0 + np.tanh(np.sqrt(2 / np.pi) * (x + 0.044715 * np.power(x, 3))))
return y_ref.astype(x.dtype)
def relu(x):
mask = x > 0
return x * mask
def get_output(X, Y, bias, act):
out = np.dot(X, Y) + bias
if act == 'relu':
return relu(out)
elif act == 'gelu':
return gelu(out)
else:
return out
@skip_check_grad_ci(reason="no grap op")
@unittest.skipIf(not core.is_compiled_with_cuda(),
"core is not compiled with CUDA")
class TestFuseGemmEpilogueOpReluMMFP16(OpTest):
def setUp(self):
self.op_type = "fused_gemm_epilogue"
self.place = core.CUDAPlace(0)
self.init_dtype_type()
self.inputs = {
'X': np.random.random((8, 4)).astype(self.dtype) - 0.5,
'Y': np.random.random((4, 128)).astype(self.dtype) - 0.5,
'Bias': np.random.random((128, )).astype(self.dtype) - 0.5
}
self.outputs = {
'Out': get_output(self.inputs['X'], self.inputs['Y'],
self.inputs['Bias'], 'relu')
}
self.attrs = {"activation": 'relu'}
def init_dtype_type(self):
self.dtype = np.float16
self.atol = 1e-3
def test_check_output(self):
if self.dtype == np.float16 and not core.is_float16_supported(
self.place):
return
self.check_output_with_place(self.place, atol=self.atol)
@skip_check_grad_ci(reason="no grap op")
@unittest.skipIf(not core.is_compiled_with_cuda(),
"core is not compiled with CUDA")
class TestFuseGemmEpilogueOpReluMMFP32(TestFuseGemmEpilogueOpReluMMFP16):
def init_dtype_type(self):
self.dtype = np.single
self.atol = 1e-6
@skip_check_grad_ci(reason="no grap op")
@unittest.skipIf(not core.is_compiled_with_cuda(),
"core is not compiled with CUDA")
class TestFuseGemmEpilogueOpReluMMFP64(TestFuseGemmEpilogueOpReluMMFP16):
def init_dtype_type(self):
self.dtype = np.double
self.atol = 1e-6
@skip_check_grad_ci(reason="no grap op")
@unittest.skipIf(not core.is_compiled_with_cuda(),
"core is not compiled with CUDA")
class TestFuseGemmEpilogueOpReluMTMFP16(OpTest):
def setUp(self):
self.op_type = "fused_gemm_epilogue"
self.place = core.CUDAPlace(0)
self.init_dtype_type()
self.inputs = {
'X': np.random.random((4, 8)).astype(self.dtype) - 0.5,
'Y': np.random.random((4, 128)).astype(self.dtype) - 0.5,
'Bias': np.random.random((128, )).astype(self.dtype) - 0.5
}
self.outputs = {
'Out': get_output(self.inputs['X'].T, self.inputs['Y'],
self.inputs['Bias'], 'relu')
}
self.attrs = {'trans_x': True, "activation": 'relu'}
def init_dtype_type(self):
self.dtype = np.float16
self.atol = 1e-3
def test_check_output(self):
if self.dtype == np.float16 and not core.is_float16_supported(
self.place):
return
self.check_output_with_place(self.place, atol=self.atol)
@skip_check_grad_ci(reason="no grap op")
@unittest.skipIf(not core.is_compiled_with_cuda(),
"core is not compiled with CUDA")
class TestFuseGemmEpilogueOpReluMTMFP32(TestFuseGemmEpilogueOpReluMTMFP16):
def init_dtype_type(self):
self.dtype = np.single
self.atol = 1e-6
@skip_check_grad_ci(reason="no grap op")
@unittest.skipIf(not core.is_compiled_with_cuda(),
"core is not compiled with CUDA")
class TestFuseGemmEpilogueOpReluMTMFP64(TestFuseGemmEpilogueOpReluMTMFP16):
def init_dtype_type(self):
self.dtype = np.double
self.atol = 1e-6
@skip_check_grad_ci(reason="no grap op")
@unittest.skipIf(not core.is_compiled_with_cuda(),
"core is not compiled with CUDA")
class TestFuseGemmEpilogueOpReluMMTFP16(OpTest):
def setUp(self):
self.op_type = "fused_gemm_epilogue"
self.place = core.CUDAPlace(0)
self.init_dtype_type()
self.inputs = {
'X': np.random.random((8, 4)).astype(self.dtype) - 0.5,
'Y': np.random.random((128, 4)).astype(self.dtype) - 0.5,
'Bias': np.random.random((128, )).astype(self.dtype) - 0.5
}
self.outputs = {
'Out': get_output(self.inputs['X'], self.inputs['Y'].T,
self.inputs['Bias'], 'relu')
}
self.attrs = {'trans_y': True, "activation": 'relu'}
def init_dtype_type(self):
self.dtype = np.float16
self.atol = 1e-3
def test_check_output(self):
if self.dtype == np.float16 and not core.is_float16_supported(
self.place):
return
self.check_output_with_place(self.place, atol=self.atol)
@skip_check_grad_ci(reason="no grap op")
@unittest.skipIf(not core.is_compiled_with_cuda(),
"core is not compiled with CUDA")
class TestFuseGemmEpilogueOpReluMMTFP32(TestFuseGemmEpilogueOpReluMMTFP16):
def init_dtype_type(self):
self.dtype = np.single
self.atol = 1e-6
@skip_check_grad_ci(reason="no grap op")
@unittest.skipIf(not core.is_compiled_with_cuda(),
"core is not compiled with CUDA")
class TestFuseGemmEpilogueOpReluMMTFP64(TestFuseGemmEpilogueOpReluMMTFP16):
def init_dtype_type(self):
self.dtype = np.double
self.atol = 1e-6
@skip_check_grad_ci(reason="no grap op")
@unittest.skipIf(not core.is_compiled_with_cuda(),
"core is not compiled with CUDA")
class TestFuseGemmEpilogueOpReluMTMTFP16(OpTest):
def setUp(self):
self.op_type = "fused_gemm_epilogue"
self.place = core.CUDAPlace(0)
self.init_dtype_type()
self.inputs = {
'X': np.random.random((4, 8)).astype(self.dtype) - 0.5,
'Y': np.random.random((128, 4)).astype(self.dtype) - 0.5,
'Bias': np.random.random((128, )).astype(self.dtype) - 0.5
}
self.outputs = {
'Out': get_output(self.inputs['X'].T, self.inputs['Y'].T,
self.inputs['Bias'], 'relu')
}
self.attrs = {'trans_x': True, 'trans_y': True, "activation": 'relu'}
def init_dtype_type(self):
self.dtype = np.float16
self.atol = 1e-3
def test_check_output(self):
if self.dtype == np.float16 and not core.is_float16_supported(
self.place):
return
self.check_output_with_place(self.place, atol=self.atol)
@skip_check_grad_ci(reason="no grap op")
@unittest.skipIf(not core.is_compiled_with_cuda(),
"core is not compiled with CUDA")
class TestFuseGemmEpilogueOpReluMTMTFP32(TestFuseGemmEpilogueOpReluMTMTFP16):
def init_dtype_type(self):
self.dtype = np.single
self.atol = 1e-6
@skip_check_grad_ci(reason="no grap op")
@unittest.skipIf(not core.is_compiled_with_cuda(),
"core is not compiled with CUDA")
class TestFuseGemmEpilogueOpReluMTMTFP64(TestFuseGemmEpilogueOpReluMTMTFP16):
def init_dtype_type(self):
self.dtype = np.double
self.atol = 1e-6
@skip_check_grad_ci(reason="no grap op")
@unittest.skipIf(not core.is_compiled_with_cuda(),
"core is not compiled with CUDA")
class TestFuseGemmEpilogueOpReluMMFP16MultiDimX(OpTest):
def setUp(self):
self.op_type = "fused_gemm_epilogue"
self.place = core.CUDAPlace(0)
self.init_dtype_type()
self.inputs = {
'X': np.random.random((2, 2, 8, 4)).astype(self.dtype) - 0.5,
'Y': np.random.random((4, 128)).astype(self.dtype) - 0.5,
'Bias': np.random.random((128, )).astype(self.dtype) - 0.5
}
self.outputs = {
'Out': get_output(self.inputs['X'].reshape(
(-1, 4)), self.inputs['Y'], self.inputs['Bias'],
'relu').reshape((2, 2, 8, 128))
}
self.attrs = {"activation": 'relu'}
def init_dtype_type(self):
self.dtype = np.float16
self.atol = 1e-3
def test_check_output(self):
if self.dtype == np.float16 and not core.is_float16_supported(
self.place):
return
self.check_output_with_place(self.place, atol=self.atol)
@skip_check_grad_ci(reason="no grap op")
@unittest.skipIf(not core.is_compiled_with_cuda(),
"core is not compiled with CUDA")
class TestFuseGemmEpilogueOpReluMMFP32MultiDimX(
TestFuseGemmEpilogueOpReluMMFP16MultiDimX):
def init_dtype_type(self):
self.dtype = np.single
self.atol = 1e-6
@skip_check_grad_ci(reason="no grap op")
@unittest.skipIf(not core.is_compiled_with_cuda(),
"core is not compiled with CUDA")
class TestFuseGemmEpilogueOpReluMMFP64MultiDimX(
TestFuseGemmEpilogueOpReluMMFP16MultiDimX):
def init_dtype_type(self):
self.dtype = np.double
self.atol = 1e-6
@skip_check_grad_ci(reason="no grap op")
@unittest.skipIf(not core.is_compiled_with_cuda(),
"core is not compiled with CUDA")
class TestFuseGemmEpilogueOpReluMTMFP16MultiDimX(OpTest):
def setUp(self):
self.op_type = "fused_gemm_epilogue"
self.place = core.CUDAPlace(0)
self.init_dtype_type()
self.inputs = {
'X': np.random.random((4, 2, 2, 8)).astype(self.dtype) - 0.5,
'Y': np.random.random((4, 128)).astype(self.dtype) - 0.5,
'Bias': np.random.random((128, )).astype(self.dtype) - 0.5
}
self.outputs = {
'Out': get_output(self.inputs['X'].reshape(
(4, -1)).T, self.inputs['Y'], self.inputs['Bias'],
'relu').reshape((2, 2, 8, 128))
}
self.attrs = {'trans_x': True, "activation": 'relu'}
def init_dtype_type(self):
self.dtype = np.float16
self.atol = 1e-3
def test_check_output(self):
if self.dtype == np.float16 and not core.is_float16_supported(
self.place):
return
self.check_output_with_place(self.place, atol=self.atol)
@skip_check_grad_ci(reason="no grap op")
@unittest.skipIf(not core.is_compiled_with_cuda(),
"core is not compiled with CUDA")
class TestFuseGemmEpilogueOpReluMTMFP32MultiDimX(
TestFuseGemmEpilogueOpReluMTMFP16MultiDimX):
def init_dtype_type(self):
self.dtype = np.single
self.atol = 1e-6
@skip_check_grad_ci(reason="no grap op")
@unittest.skipIf(not core.is_compiled_with_cuda(),
"core is not compiled with CUDA")
class TestFuseGemmEpilogueOpReluMTMFP64MultiDimX(
TestFuseGemmEpilogueOpReluMTMFP16MultiDimX):
def init_dtype_type(self):
self.dtype = np.double
self.atol = 1e-6
@skip_check_grad_ci(reason="no grap op")
@unittest.skipIf(not core.is_compiled_with_cuda(),
"core is not compiled with CUDA")
class TestFuseGemmEpilogueOpGeluMMFP16(OpTest):
def setUp(self):
self.op_type = "fused_gemm_epilogue"
self.place = core.CUDAPlace(0)
self.init_dtype_type()
self.inputs = {
'X': np.random.random((8, 4)).astype(self.dtype) - 0.5,
'Y': np.random.random((4, 128)).astype(self.dtype) - 0.5,
'Bias': np.random.random((128, )).astype(self.dtype) - 0.5
}
self.attrs = {"activation": 'gelu'}
self.outputs = {
'Out': get_output(self.inputs['X'], self.inputs['Y'],
self.inputs['Bias'], 'gelu')
}
def init_dtype_type(self):
self.dtype = np.float16
self.atol = 1e-3
def test_check_output(self):
if self.dtype == np.float16 and not core.is_float16_supported(
self.place):
return
self.check_output_with_place(self.place, atol=self.atol)
@skip_check_grad_ci(reason="no grap op")
@unittest.skipIf(not core.is_compiled_with_cuda(),
"core is not compiled with CUDA")
class TestFuseGemmEpilogueOpGeluMMFP32(TestFuseGemmEpilogueOpGeluMMFP16):
def init_dtype_type(self):
self.dtype = np.single
self.atol = 1e-6
@skip_check_grad_ci(reason="no grap op")
@unittest.skipIf(not core.is_compiled_with_cuda(),
"core is not compiled with CUDA")
class TestFuseGemmEpilogueOpGeluMMFP64(TestFuseGemmEpilogueOpGeluMMFP16):
def init_dtype_type(self):
self.dtype = np.double
self.atol = 1e-6
@skip_check_grad_ci(reason="no grap op")
@unittest.skipIf(not core.is_compiled_with_cuda(),
"core is not compiled with CUDA")
class TestFuseGemmEpilogueOpNoneMMFP16(OpTest):
def setUp(self):
self.op_type = "fused_gemm_epilogue"
self.place = core.CUDAPlace(0)
self.init_dtype_type()
self.inputs = {
'X': np.random.random((8, 4)).astype(self.dtype) - 0.5,
'Y': np.random.random((4, 128)).astype(self.dtype) - 0.5,
'Bias': np.random.random((128, )).astype(self.dtype) - 0.5
}
self.attrs = {"activation": 'none'}
self.outputs = {
'Out': get_output(self.inputs['X'], self.inputs['Y'],
self.inputs['Bias'], 'none')
}
def init_dtype_type(self):
self.dtype = np.float16
self.atol = 1e-3
def test_check_output(self):
if self.dtype == np.float16 and not core.is_float16_supported(
self.place):
return
self.check_output_with_place(self.place, atol=self.atol)
@skip_check_grad_ci(reason="no grap op")
@unittest.skipIf(not core.is_compiled_with_cuda(),
"core is not compiled with CUDA")
class TestFuseGemmEpilogueOpNoneMMFP32(TestFuseGemmEpilogueOpNoneMMFP16):
def init_dtype_type(self):
self.dtype = np.single
self.atol = 1e-6
@skip_check_grad_ci(reason="no grap op")
@unittest.skipIf(not core.is_compiled_with_cuda(),
"core is not compiled with CUDA")
class TestFuseGemmEpilogueOpNoneMMFP64(TestFuseGemmEpilogueOpNoneMMFP16):
def init_dtype_type(self):
self.dtype = np.double
self.atol = 1e-6
if __name__ == "__main__":
np.random.seed(0)
unittest.main()
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
# Copyright (c) 2022 NVIDIA Corporation. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -729,4 +730,6 @@ STATIC_MODE_TESTING_LIST = [
'test_lu_op',
'test_margin_cross_entropy_op',
'test_pull_gpups_sparse_op',
'test_fused_gemm_epilogue_op',
'test_fused_gemm_epilogue_grad_op',
]
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册