提交 0ed26e12 编写于 作者: R root

support weight transpose

上级 60b86b2f
......@@ -118,6 +118,7 @@ message BuildStrategy {
optional bool fix_op_run_order = 13 [ default = false ];
optional bool allow_cuda_graph_capture = 14 [ default = false ];
optional int32 reduce_strategy = 15 [ default = 0 ];
optional bool fuse_gemm_epilogue = 16 [ default = false ];
}
message ExecutionStrategy {
......
......@@ -18,18 +18,28 @@
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/platform/enforce.h"
PADDLE_DEFINE_EXPORTED_bool(enable_gemm_fwd_fusion, true, "");
namespace paddle {
namespace framework {
namespace ir {
static void GetTransposeAttrsFromOp(const OpDesc &op, bool *trans_x,
bool *trans_y) {
*trans_x = BOOST_GET_CONST(bool, op.GetAttr("trans_x"));
*trans_y = BOOST_GET_CONST(bool, op.GetAttr("trans_y"));
}
void FuseGemmEpiloguePass::ApplyImpl(ir::Graph *graph) const {
EpiloguePassActivationCache cache;
if (FLAGS_enable_gemm_fwd_fusion) {
graph = FuseLinearActFwd(graph, {"relu", "gelu"}, false, false, &cache);
graph = FuseLinearActFwd(graph, {"relu"}, true, true, &cache);
graph = FuseLinearActFwd(graph, {"gelu"}, true, false, &cache);
graph = FuseLinearFwd(graph, false);
graph = FuseLinearFwd(graph, true);
}
graph = FuseLinearActBwd(graph, {"relu_grad"}, true, &cache);
graph = FuseLinearActBwd(graph, {"gelu_grad"}, false, &cache);
graph = FuseLinearBwd(graph, false);
......@@ -75,6 +85,9 @@ ir::Graph *FuseGemmEpiloguePass::FuseLinearFwd(ir::Graph *graph,
if (!IsGemmFromLinear_(matmul_x_shape, matmul_w_shape, matmul_op_desc))
return;
bool trans_x, trans_y;
GetTransposeAttrsFromOp(*matmul_op_desc, &trans_x, &trans_y);
OpDesc fused_gemm_epilogue_op_desc(matmul_op->Op()->Block());
std::string activation = "none";
fused_gemm_epilogue_op_desc.SetType("fused_gemm_epilogue");
......@@ -85,6 +98,8 @@ ir::Graph *FuseGemmEpiloguePass::FuseLinearFwd(ir::Graph *graph,
fused_gemm_epilogue_op_desc.SetAttr("activation", activation);
fused_gemm_epilogue_op_desc.SetAttr("op_role",
matmul_op_desc->GetAttr("op_role"));
fused_gemm_epilogue_op_desc.SetAttr("trans_x", trans_x);
fused_gemm_epilogue_op_desc.SetAttr("trans_y", trans_y);
auto gemm_epilogue_node = g->CreateOpNode(&fused_gemm_epilogue_op_desc);
IR_NODE_LINK_TO(subgraph.at(x), gemm_epilogue_node);
......@@ -154,6 +169,9 @@ ir::Graph *FuseGemmEpiloguePass::FuseLinearActFwd(
auto activation = act_op->Op()->Type();
bool trans_x, trans_y;
GetTransposeAttrsFromOp(*matmul_op_desc, &trans_x, &trans_y);
OpDesc fused_gemm_epilogue_op_desc(matmul_op->Op()->Block());
fused_gemm_epilogue_op_desc.SetType("fused_gemm_epilogue");
fused_gemm_epilogue_op_desc.SetInput("X", {subgraph.at(x)->Name()});
......@@ -163,6 +181,8 @@ ir::Graph *FuseGemmEpiloguePass::FuseLinearActFwd(
fused_gemm_epilogue_op_desc.SetAttr("activation", activation);
fused_gemm_epilogue_op_desc.SetAttr("op_role",
matmul_op_desc->GetAttr("op_role"));
fused_gemm_epilogue_op_desc.SetAttr("trans_x", trans_x);
fused_gemm_epilogue_op_desc.SetAttr("trans_y", trans_y);
auto gemm_epilogue_node = g->CreateOpNode(&fused_gemm_epilogue_op_desc);
......@@ -274,6 +294,9 @@ ir::Graph *FuseGemmEpiloguePass::FuseLinearBwd(ir::Graph *graph,
matmul_grad_op_desc))
return;
bool trans_x, trans_y;
GetTransposeAttrsFromOp(*matmul_grad_op_desc, &trans_x, &trans_y);
OpDesc fused_gemm_epilogue_grad_op_desc(ele_add_grad_op->Op()->Block());
std::string activation_grad = "none";
fused_gemm_epilogue_grad_op_desc.SetType("fused_gemm_epilogue_grad");
......@@ -292,6 +315,8 @@ ir::Graph *FuseGemmEpiloguePass::FuseLinearBwd(ir::Graph *graph,
activation_grad);
fused_gemm_epilogue_grad_op_desc.SetAttr(
"op_role", matmul_grad_op_desc->GetAttr("op_role"));
fused_gemm_epilogue_grad_op_desc.SetAttr("trans_x", trans_x);
fused_gemm_epilogue_grad_op_desc.SetAttr("trans_y", trans_y);
auto gemm_epilogue_grad_node =
g->CreateOpNode(&fused_gemm_epilogue_grad_op_desc);
......@@ -394,6 +419,8 @@ ir::Graph *FuseGemmEpiloguePass::FuseLinearActBwd(
auto activation_grad = act_grad_op->Op()->Type();
bool trans_x, trans_y;
GetTransposeAttrsFromOp(*matmul_grad_op_desc, &trans_x, &trans_y);
OpDesc fused_gemm_epilogue_grad_op_desc(ele_add_grad_op->Op()->Block());
fused_gemm_epilogue_grad_op_desc.SetType("fused_gemm_epilogue_grad");
fused_gemm_epilogue_grad_op_desc.SetInput("DOut",
......@@ -410,6 +437,8 @@ ir::Graph *FuseGemmEpiloguePass::FuseLinearActBwd(
activation_grad);
fused_gemm_epilogue_grad_op_desc.SetAttr(
"op_role", matmul_grad_op_desc->GetAttr("op_role"));
fused_gemm_epilogue_grad_op_desc.SetAttr("trans_x", trans_x);
fused_gemm_epilogue_grad_op_desc.SetAttr("trans_y", trans_y);
auto gemm_epilogue_grad_node =
g->CreateOpNode(&fused_gemm_epilogue_grad_op_desc);
......@@ -456,10 +485,6 @@ bool FuseGemmEpiloguePass::IsGemmFromLinear_(
if (tmp_vec.size() > 0) return false;
}
}
if (BOOST_GET_CONST(bool, matmul_v2_op->GetAttr("trans_x")) ||
BOOST_GET_CONST(bool, matmul_v2_op->GetAttr("trans_y")))
return false;
return true;
}
......
......@@ -208,6 +208,9 @@ class FusedGemmEpilogueGradOp : public framework::OperatorWithKernel {
auto x_dims = ctx->GetInputDim("X");
auto y_dims = ctx->GetInputDim("Y");
auto trans_x = ctx->Attrs().Get<bool>("trans_x");
auto trans_y = ctx->Attrs().Get<bool>("trans_y");
PADDLE_ENFORCE_GE(
dout_dims.size(), 2,
platform::errors::InvalidArgument(
......@@ -242,14 +245,14 @@ class FusedGemmEpilogueGradOp : public framework::OperatorWithKernel {
auto x_mat_dims = phi::flatten_to_2d(x_dims, x_dims.size() - 1);
PADDLE_ENFORCE_EQ(
dout_mat_dims[1], y_dims[1],
dout_mat_dims[1], trans_y ? y_dims[0] : y_dims[1],
platform::errors::InvalidArgument(
"The last dimension of DOut should be equal with Y's last"
"dimension. But received DOut[-1] = [%d], Y[1] = [%d].",
dout_mat_dims[1], y_dims[1]));
PADDLE_ENFORCE_EQ(
dout_mat_dims[0], x_mat_dims[0],
dout_mat_dims[0], trans_x ? x_mat_dims[1] : x_mat_dims[0],
platform::errors::InvalidArgument(
"The first dimension of DOut should be equal with X's first"
"dimension. But received DOut[0] = [%d], Y[0] = [%d].",
......@@ -323,6 +326,8 @@ class FusedGemmEpilogueGradOpMaker : public framework::OpProtoAndCheckerMaker {
AddOutput("DBias",
"The output grad tensor to bias of Out = (Act(X) * Y) + bias.")
.AsDispensable();
AddAttr<bool>("trans_x", "").SetDefault(false);
AddAttr<bool>("trans_y", "").SetDefault(false);
AddAttr<std::string>(
"activation_grad",
......
......@@ -40,6 +40,9 @@ class FusedGemmEpilogueKernel : public framework::OpKernel<T> {
bool trans_y = ctx.Attr<bool>("trans_y");
std::string activation = ctx.Attr<std::string>("activation");
VLOG(10) << "trans_x = " << trans_x << " , trans_y = " << trans_y
<< " , activation = " << activation;
// activation = "none";
bool enable_auxiliary = reserve_space == nullptr ? false : true;
out->mutable_data<T>(ctx.GetPlace());
......@@ -56,7 +59,7 @@ class FusedGemmEpilogueKernel : public framework::OpKernel<T> {
cublasComputeType_t compute_type = CUBLAS_COMPUTE_32F;
if (std::is_same<T, paddle::platform::float16>::value) {
mat_type = CUDA_R_16F;
scale_type = CUDA_R_16F;
scale_type = CUDA_R_32F;
}
if (std::is_same<T, double>::value) {
mat_type = CUDA_R_64F;
......@@ -106,10 +109,12 @@ class FusedGemmEpilogueKernel : public framework::OpKernel<T> {
platform::dynload::cublasLtMatmulDescSetAttribute(
operation_desc, CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_POINTER,
&aux_data, sizeof(aux_data)));
// int64_t aux_ld = trans_y ? K : N;
int64_t aux_ld = N;
PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::cublasLtMatmulDescSetAttribute(
operation_desc, CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_LD, &N,
sizeof(N)));
operation_desc, CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_LD, &aux_ld,
sizeof(aux_ld)));
}
cublasLtMatrixLayout_t x_desc = NULL, y_desc = NULL, out_desc = NULL;
......@@ -129,7 +134,7 @@ class FusedGemmEpilogueKernel : public framework::OpKernel<T> {
&out_desc, mat_type, N, M, N));
cublasLtHandle_t lt_handle = dev_ctx.cublaslt_handle();
size_t workspace_size = 4 * 1024 * 1024;
size_t workspace_size = static_cast<size_t>(4) * 1024 * 1024 * 1024;
const cublasLtMatmulAlgo_t* algo = nullptr;
cudaStream_t stream = dev_ctx.stream();
memory::allocation::AllocationPtr workspace =
......@@ -192,20 +197,27 @@ class FusedGemmEpilogueGradKernel : public framework::OpKernel<T> {
std::string activation_grad = ctx.Attr<std::string>("activation_grad");
auto dout_mat_dims =
phi::flatten_to_2d(dout->dims(), dout->dims().size() - 1);
auto x_mat_dims = phi::flatten_to_2d(x->dims(), x->dims().size() - 1);
bool transpose_x = ctx.Attr<bool>("trans_x");
bool transpose_y = ctx.Attr<bool>("trans_y");
int64_t M = x_mat_dims[0];
int64_t K = y->dims()[0];
int64_t N = y->dims()[1];
VLOG(10) << "trans_x = " << transpose_x << " , trans_y = " << transpose_y
<< " , activation_grad = " << activation_grad;
// activation_grad = "none";
auto x_mat_dims =
phi::flatten_to_2d(x->dims(), transpose_x ? 1 : x->dims().size() - 1);
int64_t M = transpose_x ? x_mat_dims[1] : x_mat_dims[0];
int64_t K = transpose_y ? y->dims()[1] : y->dims()[0];
int64_t N = transpose_y ? y->dims()[0] : y->dims()[1];
cudaDataType_t mat_type = CUDA_R_32F;
cudaDataType_t scale_type = CUDA_R_32F;
cublasComputeType_t compute_type = CUBLAS_COMPUTE_32F;
if (std::is_same<T, paddle::platform::float16>::value) {
mat_type = CUDA_R_16F;
scale_type = CUDA_R_16F;
scale_type = CUDA_R_32F;
}
if (std::is_same<T, double>::value) {
mat_type = CUDA_R_64F;
......@@ -214,7 +226,7 @@ class FusedGemmEpilogueGradKernel : public framework::OpKernel<T> {
}
cublasLtHandle_t lt_handle = dev_ctx.cublaslt_handle();
size_t workspace_size = 4 * 1024 * 1024;
size_t workspace_size = static_cast<size_t>(4) * 1024 * 1024 * 1024;
const cublasLtMatmulAlgo_t* algo = nullptr;
cudaStream_t stream = dev_ctx.stream();
......@@ -229,16 +241,54 @@ class FusedGemmEpilogueGradKernel : public framework::OpKernel<T> {
beta = &beta32;
}
cublasOperation_t trans_dout = CUBLAS_OP_N;
cublasLtMatrixLayout_t dout_desc = NULL;
cublasLtMatrixLayout_t dout_desc = nullptr, dout_trans_desc = nullptr;
if (dx) {
cublasOperation_t trans_dout = transpose_x ? CUBLAS_OP_T : CUBLAS_OP_N;
cublasOperation_t trans_y =
(transpose_x ^ transpose_y) ? CUBLAS_OP_N : CUBLAS_OP_T;
cublasLtMatrixLayout_t dout_desc_for_dx, y_desc, dx_desc;
if (trans_dout == CUBLAS_OP_T) {
PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::cublasLtMatrixLayoutCreate(&dout_trans_desc,
mat_type, M, N, M));
dout_desc_for_dx = dout_trans_desc;
} else {
PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::cublasLtMatrixLayoutCreate(&dout_desc, mat_type,
N, M, N));
dout_desc_for_dx = dout_desc;
}
if (transpose_y) {
PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::cublasLtMatrixLayoutCreate(&y_desc, mat_type, K,
N, K));
} else {
PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::cublasLtMatrixLayoutCreate(&y_desc, mat_type, N,
K, N));
}
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasLtMatrixLayoutCreate(
&dout_desc, mat_type, N, M, N));
&dx_desc, mat_type, K, M, K));
if (dx) {
cublasLtMatmulDesc_t dx_operation_desc = NULL;
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasLtMatmulDescCreate(
&dx_operation_desc, compute_type, scale_type));
cublasOperation_t trans_y = CUBLAS_OP_T;
if (transpose_x) {
// dx = B * dout
PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::cublasLtMatmulDescSetAttribute(
dx_operation_desc, CUBLASLT_MATMUL_DESC_TRANSA, &trans_dout,
sizeof(trans_dout)));
PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::cublasLtMatmulDescSetAttribute(
dx_operation_desc, CUBLASLT_MATMUL_DESC_TRANSB, &trans_y,
sizeof(trans_y)));
} else {
// dx = dout * B
PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::cublasLtMatmulDescSetAttribute(
dx_operation_desc, CUBLASLT_MATMUL_DESC_TRANSB, &trans_dout,
......@@ -247,6 +297,8 @@ class FusedGemmEpilogueGradKernel : public framework::OpKernel<T> {
platform::dynload::cublasLtMatmulDescSetAttribute(
dx_operation_desc, CUBLASLT_MATMUL_DESC_TRANSA, &trans_y,
sizeof(trans_y)));
}
cublasLtEpilogue_t epiloque_func_for_dx =
get_epilogue_type_(activation_grad);
PADDLE_ENFORCE_GPU_SUCCESS(
......@@ -260,18 +312,13 @@ class FusedGemmEpilogueGradKernel : public framework::OpKernel<T> {
platform::dynload::cublasLtMatmulDescSetAttribute(
dx_operation_desc, CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_POINTER,
&aux_data, sizeof(aux_data)));
int64_t aux_ld = transpose_x ? M : K;
PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::cublasLtMatmulDescSetAttribute(
dx_operation_desc, CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_LD, &N,
sizeof(N)));
dx_operation_desc, CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_LD,
&aux_ld, sizeof(aux_ld)));
}
cublasLtMatrixLayout_t y_desc = NULL, dx_desc = NULL;
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasLtMatrixLayoutCreate(
&y_desc, mat_type, N, K, N));
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasLtMatrixLayoutCreate(
&dx_desc, mat_type, K, M, K));
memory::allocation::AllocationPtr dx_workspace =
memory::Alloc(dev_ctx, workspace_size);
......@@ -284,10 +331,41 @@ class FusedGemmEpilogueGradKernel : public framework::OpKernel<T> {
}
if (dy) {
cublasOperation_t trans_dout = transpose_y ? CUBLAS_OP_T : CUBLAS_OP_N;
cublasOperation_t trans_x =
(transpose_x ^ transpose_y) ? CUBLAS_OP_N : CUBLAS_OP_T;
cublasLtMatrixLayout_t dout_desc_for_dx;
if (trans_dout == CUBLAS_OP_T) {
if (dout_trans_desc == nullptr) {
PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::cublasLtMatrixLayoutCreate(&dout_trans_desc,
mat_type, M, N, M));
}
dout_desc_for_dx = dout_trans_desc;
} else {
if (dout_desc == nullptr) {
PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::cublasLtMatrixLayoutCreate(&dout_desc,
mat_type, N, M, N));
}
dout_desc_for_dx = dout_desc;
}
cublasLtMatmulDesc_t dy_operation_desc = NULL;
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasLtMatmulDescCreate(
&dy_operation_desc, compute_type, scale_type));
cublasOperation_t trans_x = CUBLAS_OP_T;
if (transpose_y) {
PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::cublasLtMatmulDescSetAttribute(
dy_operation_desc, CUBLASLT_MATMUL_DESC_TRANSB, &trans_dout,
sizeof(trans_dout)));
PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::cublasLtMatmulDescSetAttribute(
dy_operation_desc, CUBLASLT_MATMUL_DESC_TRANSA, &trans_x,
sizeof(trans_x)));
} else {
PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::cublasLtMatmulDescSetAttribute(
dy_operation_desc, CUBLASLT_MATMUL_DESC_TRANSA, &trans_dout,
......@@ -296,9 +374,13 @@ class FusedGemmEpilogueGradKernel : public framework::OpKernel<T> {
platform::dynload::cublasLtMatmulDescSetAttribute(
dy_operation_desc, CUBLASLT_MATMUL_DESC_TRANSB, &trans_x,
sizeof(trans_x)));
cublasLtEpilogue_t epiloque_func_for_dy = dbias == nullptr
? CUBLASLT_EPILOGUE_DEFAULT
: CUBLASLT_EPILOGUE_BGRADA;
}
cublasLtEpilogue_t epiloque_func_for_dy =
dbias == nullptr ? CUBLASLT_EPILOGUE_DEFAULT
: (transpose_y ? CUBLASLT_EPILOGUE_BGRADB
: CUBLASLT_EPILOGUE_BGRADA);
PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::cublasLtMatmulDescSetAttribute(
dy_operation_desc, CUBLASLT_MATMUL_DESC_EPILOGUE,
......@@ -314,8 +396,16 @@ class FusedGemmEpilogueGradKernel : public framework::OpKernel<T> {
}
cublasLtMatrixLayout_t x_desc = NULL, dy_desc = NULL;
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasLtMatrixLayoutCreate(
&x_desc, mat_type, K, M, K));
if (transpose_x) {
PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::cublasLtMatrixLayoutCreate(&x_desc, mat_type, M,
K, M));
} else {
PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::cublasLtMatrixLayoutCreate(&x_desc, mat_type, K,
M, K));
}
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasLtMatrixLayoutCreate(
&dy_desc, mat_type, N, K, N));
......
......@@ -58,6 +58,14 @@ class MultiFCLayer(paddle.nn.Layer):
self.relu3 = Activation()
def forward(self, x, matmul_y, ele_y):
x = self.linear1(x)
x = self.relu1(x)
x = self.linear2(x)
x = self.relu2(x)
x = self.linear3(x)
x = self.relu3(x)
return x
'''
output = self.linear1(x)
output = self.relu1(output)
output = self.linear2(output)
......@@ -71,8 +79,10 @@ class MultiFCLayer(paddle.nn.Layer):
output = self.relu3(output)
output = paddle.add(output, output1)
return output
'''
'''
@unittest.skipIf(not core.is_compiled_with_cuda(),
"core is not compiled with CUDA")
class TestFuseGemmEpilogueFWDBase(unittest.TestCase):
......@@ -218,6 +228,7 @@ class TestFuseGemmEpilogueGeluFWDFP16(TestFuseGemmEpilogueGeluFWDFP32):
self.data_arr = self.data_arr.astype("float16")
self.matmul_y_arr = self.matmul_y_arr.astype("float16")
self.ele_y_arr = self.ele_y_arr.astype("float16")
'''
@unittest.skipIf(not core.is_compiled_with_cuda(),
......@@ -327,6 +338,7 @@ class TestFuseGemmEpilogueBWDBase(unittest.TestCase):
return paddle.nn.ReLU, "relu", "relu_grad"
'''
@unittest.skipIf(not core.is_compiled_with_cuda(),
"core is not compiled with CUDA")
class TestFuseGemmEpilogueReLUBWDFP32(TestFuseGemmEpilogueBWDBase):
......@@ -339,8 +351,8 @@ class TestFuseGemmEpilogueReLUBWDFP32(TestFuseGemmEpilogueBWDBase):
def test_output(self):
self._test_output()
'''
'''
@unittest.skipIf(not core.is_compiled_with_cuda(),
"core is not compiled with CUDA")
class TestFuseGemmEpilogueReLUBWDFP16(TestFuseGemmEpilogueReLUBWDFP32):
......@@ -355,6 +367,7 @@ class TestFuseGemmEpilogueReLUBWDFP16(TestFuseGemmEpilogueReLUBWDFP32):
self.data_arr = self.data_arr.astype("float16")
self.matmul_y_arr = self.matmul_y_arr.astype("float16")
self.ele_y_arr = self.ele_y_arr.astype("float16")
'''
@unittest.skipIf(not core.is_compiled_with_cuda(),
......@@ -371,6 +384,7 @@ class TestFuseGemmEpilogueGeLUBWDFP32(TestFuseGemmEpilogueBWDBase):
self._test_output()
'''
@unittest.skipIf(not core.is_compiled_with_cuda(),
"core is not compiled with CUDA")
class TestFuseGemmEpilogueGeLUBWDFP16(TestFuseGemmEpilogueGeLUBWDFP32):
......@@ -385,7 +399,7 @@ class TestFuseGemmEpilogueGeLUBWDFP16(TestFuseGemmEpilogueGeLUBWDFP32):
self.data_arr = self.data_arr.astype("float16")
self.matmul_y_arr = self.matmul_y_arr.astype("float16")
self.ele_y_arr = self.ele_y_arr.astype("float16")
'''
if __name__ == "__main__":
np.random.seed(0)
......
......@@ -235,5 +235,6 @@ class TestFuseGemmEpilogueGradOpDXYFP64(TestFuseGemmEpilogueGradOpDXYFP16):
if __name__ == "__main__":
paddle.enable_static()
np.random.seed(0)
unittest.main()
......@@ -446,5 +446,6 @@ class TestFuseGemmEpilogueOpNoneMMFP64(TestFuseGemmEpilogueOpNoneMMFP16):
if __name__ == "__main__":
paddle.enable_static()
np.random.seed(0)
unittest.main()
......@@ -1470,7 +1470,7 @@ def cosine_similarity(x1, x2, axis=1, eps=1e-8):
return cos_sim
def linear(x, weight, bias=None, name=None):
def linear(x, weight, bias=None, name=None, weight_transpose=False):
r"""
Fully-connected linear transformation operator. For each input :math:`X` ,
......@@ -1523,7 +1523,7 @@ def linear(x, weight, bias=None, name=None):
"""
if in_dynamic_mode():
pre_bias = _C_ops.matmul_v2(x, weight, 'trans_x', False, 'trans_y',
False)
weight_transpose)
if bias is None:
return pre_bias
......@@ -1538,7 +1538,7 @@ def linear(x, weight, bias=None, name=None):
check_dtype(dtype, 'dtype', ['float16', 'float32', 'float64'], 'linear')
inputs = {'X': [x], 'Y': [weight]}
attrs = {'trans_x': False, 'trans_y': False}
attrs = {'trans_x': False, 'trans_y': weight_transpose}
tmp = helper.create_variable_for_type_inference(dtype)
helper.append_op(
type='matmul_v2', inputs=inputs, outputs={'Out': tmp}, attrs=attrs)
......
......@@ -150,13 +150,15 @@ class Linear(Layer):
out_features,
weight_attr=None,
bias_attr=None,
name=None):
name=None,
weight_transpose=False):
super(Linear, self).__init__()
self._dtype = self._helper.get_default_dtype()
self._weight_attr = weight_attr
self._bias_attr = bias_attr
self.weight = self.create_parameter(
shape=[in_features, out_features],
shape=[out_features, in_features]
if weight_transpose else [in_features, out_features],
attr=self._weight_attr,
dtype=self._dtype,
is_bias=False)
......@@ -165,11 +167,16 @@ class Linear(Layer):
attr=self._bias_attr,
dtype=self._dtype,
is_bias=True)
self.weight_transpose = weight_transpose
self.name = name
def forward(self, input):
out = F.linear(
x=input, weight=self.weight, bias=self.bias, name=self.name)
x=input,
weight=self.weight,
bias=self.bias,
name=self.name,
weight_transpose=self.weight_transpose)
return out
def extra_repr(self):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册