diff --git a/paddle/fluid/operators/matmul_op_xpu.cc b/paddle/fluid/operators/matmul_op_xpu.cc
index efad516cdbfe50f52a45195d59f315ad7eee67c5..922bf780add0bf8e834ef2a545d669ee404e492b 100644
--- a/paddle/fluid/operators/matmul_op_xpu.cc
+++ b/paddle/fluid/operators/matmul_op_xpu.cc
@@ -44,13 +44,14 @@ class MatMulXPUKernel : public framework::OpKernel<T> {
     auto x_dims = x->dims();
     auto y_dims = y->dims();
 
-    XpuFcInfo fc_info;
-    GetFCInfo(x_dims, y_dims, trans_x, trans_y, &fc_info);
+    phi::XpuFcInfo fc_info;
+    phi::GetFCInfo(x_dims, y_dims, trans_x, trans_y, &fc_info);
     auto& dev_ctx =
         context.template device_context<DeviceContext>();
     xpu::Context* xpu_ctx = dev_ctx.x_context();
-    MatMulXPUFunction<XPUType>(xpu_ctx, x_ptr, y_ptr, out_ptr, fc_info, alpha);
+    phi::MatMulXPUFunction<XPUType>(
+        xpu_ctx, x_ptr, y_ptr, out_ptr, fc_info, alpha);
   }
 };
 
@@ -109,8 +110,8 @@ class MatMulGradXPUKernel : public framework::OpKernel<T> {
 
     xpu::Context* xpu_ctx = dev_ctx.x_context();
 
-    XpuFcInfo info_forward;
-    GetFCInfo(x.dims(), y.dims(), transpose_x, transpose_y, &info_forward);
+    phi::XpuFcInfo info_forward;
+    phi::GetFCInfo(x.dims(), y.dims(), transpose_x, transpose_y, &info_forward);
     xpu::ctx_guard RAII_GUARD(xpu_ctx);
     // begin calculate
     const XPUType* a_1 = reinterpret_cast<const XPUType*>(NULL);
@@ -121,28 +122,28 @@ class MatMulGradXPUKernel : public framework::OpKernel<T> {
                                 : reinterpret_cast<XPUType*>(dx->data<T>());
     XPUType* c_2 = (dy == NULL) ? reinterpret_cast<XPUType*>(NULL)
                                 : reinterpret_cast<XPUType*>(dy->data<T>());
-    XpuFcInfo info_dx;
-    XpuFcInfo info_dy;
-    std::tuple<XpuFcInfo,
-               XpuFcInfo,
-               const XPUType*,
-               const XPUType*,
-               const XPUType*,
-               const XPUType*>
-        fc_info = MatmulGradFcInfo(xpu_ctx,
-                                   &RAII_GUARD,
-                                   info_forward,
-                                   transpose_x,
-                                   transpose_y,
-                                   x_ptr,
-                                   y_ptr,
-                                   dout_ptr);
+    phi::XpuFcInfo info_dx;
+    phi::XpuFcInfo info_dy;
+    std::tuple<phi::XpuFcInfo,
+               phi::XpuFcInfo,
+               const XPUType*,
+               const XPUType*,
+               const XPUType*,
+               const XPUType*>
+        fc_info = phi::MatmulGradFcInfo(xpu_ctx,
+                                        &RAII_GUARD,
+                                        info_forward,
+                                        transpose_x,
+                                        transpose_y,
+                                        x_ptr,
+                                        y_ptr,
+                                        dout_ptr);
     std::tie(info_dx, info_dy, a_1, b_1, a_2, b_2) = fc_info;
     if (dx) {
-      MatMulXPUFunction<XPUType>(xpu_ctx, a_1, b_1, c_1, info_dx, alpha);
+      phi::MatMulXPUFunction<XPUType>(xpu_ctx, a_1, b_1, c_1, info_dx, alpha);
     }
     if (dy) {
-      MatMulXPUFunction<XPUType>(xpu_ctx, a_2, b_2, c_2, info_dy, alpha);
+      phi::MatMulXPUFunction<XPUType>(xpu_ctx, a_2, b_2, c_2, info_dy, alpha);
     }
   }
 };
diff --git a/paddle/fluid/operators/matmul_v2_op_xpu.cc b/paddle/fluid/operators/matmul_v2_op_xpu.cc
deleted file mode 100644
index 7b4195c1c19fa2a7a489c84af7826dec1ff5cd46..0000000000000000000000000000000000000000
--- a/paddle/fluid/operators/matmul_v2_op_xpu.cc
+++ /dev/null
@@ -1,132 +0,0 @@
-// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-//     http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#ifdef PADDLE_WITH_XPU
-
-#include <string>
-#include <vector>
-#include "paddle/fluid/operators/matmul_v2_op.h"
-
-#include "paddle/fluid/operators/xpu_api_wrapper.h"
-
-namespace paddle {
-namespace operators {
-
-template <typename T>
-class MatMulV2XPUKernel : public framework::OpKernel<T> {
-  using XPUType = typename XPUTypeTrait<T>::Type;
-
- public:
-  void Compute(const paddle::framework::ExecutionContext& ctx) const override {
-    auto* x = ctx.Input<framework::Tensor>("X");
-    auto* y = ctx.Input<framework::Tensor>("Y");
-    auto* out = ctx.Output<framework::Tensor>("Out");
-    bool trans_x = ctx.Attr<bool>("trans_x");
-    bool trans_y = ctx.Attr<bool>("trans_y");
-    out->mutable_data<T>(ctx.GetPlace());
-    const XPUType* x_ptr = reinterpret_cast<const XPUType*>(x->data<T>());
-    const XPUType* y_ptr = reinterpret_cast<const XPUType*>(y->data<T>());
-    XPUType* out_ptr = reinterpret_cast<XPUType*>(out->data<T>());
-    auto x_dims = x->dims();
-    auto y_dims = y->dims();
-
-    XpuFcInfo fc_info;
-    GetFCInfo(x_dims, y_dims, trans_x, trans_y, &fc_info);
-    auto& dev_ctx =
-        ctx.template device_context<paddle::platform::XPUDeviceContext>();
-    xpu::Context* xpu_ctx = dev_ctx.x_context();
-    MatMulXPUFunction<XPUType>(xpu_ctx, x_ptr, y_ptr, out_ptr, fc_info, 1.0f);
-  }
-};
-
-template <typename T>
-class MatMulV2XPUGradKernel : public framework::OpKernel<T> {
-  using XPUType = typename XPUTypeTrait<T>::Type;
-
- public:
-  void Compute(const framework::ExecutionContext& context) const override {
-    bool transpose_x = context.Attr<bool>("trans_x");
-    bool transpose_y = context.Attr<bool>("trans_y");
-    auto x = *context.Input<framework::Tensor>("X");
-    auto y = *context.Input<framework::Tensor>("Y");
-    auto dout =
-        *context.Input<framework::Tensor>(framework::GradVarName("Out"));
-    auto* dx = context.Output<framework::Tensor>(framework::GradVarName("X"));
-    auto* dy = context.Output<framework::Tensor>(framework::GradVarName("Y"));
-    if (dx) {
-      dx->mutable_data<T>(context.GetPlace());
-    }
-    if (dy) {
-      dy->mutable_data<T>(context.GetPlace());
-    }
-    auto& dev_ctx =
-        context.template device_context<paddle::platform::XPUDeviceContext>();
-
-    const XPUType* dout_ptr = reinterpret_cast<const XPUType*>(dout.data<T>());
-    const XPUType* x_ptr = reinterpret_cast<const XPUType*>(x.data<T>());
-    const XPUType* y_ptr = reinterpret_cast<const XPUType*>(y.data<T>());
-
-    xpu::Context* xpu_ctx = dev_ctx.x_context();
-
-    XpuFcInfo info_forward;
-    GetFCInfo(x.dims(), y.dims(), transpose_x, transpose_y, &info_forward);
-    xpu::ctx_guard RAII_GUARD(xpu_ctx);
-    // begin calculate
-    const XPUType* a_1 = reinterpret_cast<const XPUType*>(NULL);
-    const XPUType* b_1 = reinterpret_cast<const XPUType*>(NULL);
-    const XPUType* a_2 = reinterpret_cast<const XPUType*>(NULL);
-    const XPUType* b_2 = reinterpret_cast<const XPUType*>(NULL);
-    XPUType* c_1 = (dx == NULL) ? reinterpret_cast<XPUType*>(NULL)
-                                : reinterpret_cast<XPUType*>(dx->data<T>());
-    XPUType* c_2 = (dy == NULL) ? reinterpret_cast<XPUType*>(NULL)
-                                : reinterpret_cast<XPUType*>(dy->data<T>());
-    XpuFcInfo info_dx;
-    XpuFcInfo info_dy;
-    std::tuple<XpuFcInfo,
-               XpuFcInfo,
-               const XPUType*,
-               const XPUType*,
-               const XPUType*,
-               const XPUType*>
-        fc_info = MatmulGradFcInfo(xpu_ctx,
-                                   &RAII_GUARD,
-                                   info_forward,
-                                   transpose_x,
-                                   transpose_y,
-                                   x_ptr,
-                                   y_ptr,
-                                   dout_ptr);
-    std::tie(info_dx, info_dy, a_1, b_1, a_2, b_2) = fc_info;
-    if (dx) {
-      MatMulXPUFunction<XPUType>(xpu_ctx, a_1, b_1, c_1, info_dx, 1.0f);
-    }
-    if (dy) {
-      MatMulXPUFunction<XPUType>(xpu_ctx, a_2, b_2, c_2, info_dy, 1.0f);
-    }
-  }
-};
-
-}  // namespace operators
-}  // namespace paddle
-
-namespace ops = paddle::operators;
-namespace plat = paddle::platform;
-REGISTER_OP_XPU_KERNEL(matmul_v2,
-                       ops::MatMulV2XPUKernel<float>,
-                       ops::MatMulV2XPUKernel<plat::float16>);
-REGISTER_OP_XPU_KERNEL(matmul_v2_grad,
-                       ops::MatMulV2XPUGradKernel<float>,
-                       ops::MatMulV2XPUGradKernel<plat::float16>);
-
-#endif
diff --git a/paddle/fluid/operators/mul_op_xpu.cc b/paddle/fluid/operators/mul_op_xpu.cc
index 727a7c0f6e52ccaee85b2674e7511559561f1436..82ea3b5aa9be268199d94486ed83ee6fd0362529 100644
--- a/paddle/fluid/operators/mul_op_xpu.cc
+++ b/paddle/fluid/operators/mul_op_xpu.cc
@@ -59,13 +59,14 @@ class MulXPUKernel : public framework::OpKernel<T> {
     auto x_dims = x_matrix.dims();
     auto y_dims = y_matrix.dims();
 
-    XpuFcInfo fc_info;
-    GetFCInfo(x_dims, y_dims, trans_a, trans_b, &fc_info);
+    phi::XpuFcInfo fc_info;
+    phi::GetFCInfo(x_dims, y_dims, trans_a, trans_b, &fc_info);
     auto& dev_ctx =
         context.template device_context<DeviceContext>();
     xpu::Context* xpu_ctx = dev_ctx.x_context();
-    MatMulXPUFunction<XPUType>(xpu_ctx, x_ptr, y_ptr, out_ptr, fc_info, 1.0f);
+    phi::MatMulXPUFunction<XPUType>(
+        xpu_ctx, x_ptr, y_ptr, out_ptr, fc_info, 1.0f);
   }
 };
 
@@ -99,8 +100,9 @@ class MulGradXPUKernel : public framework::OpKernel<T> {
     }
     auto& dev_ctx = ctx.template device_context<DeviceContext>();
-    XpuFcInfo info_forward;
-    GetFCInfo(x_matrix.dims(), y_matrix.dims(), false, false, &info_forward);
+    phi::XpuFcInfo info_forward;
+    phi::GetFCInfo(
+        x_matrix.dims(), y_matrix.dims(), false, false, &info_forward);
 
     const XPUType* dout_ptr =
         reinterpret_cast<const XPUType*>(dout->data<T>());
     const XPUType* x_ptr = reinterpret_cast<const XPUType*>(x->data<T>());
@@ -121,28 +123,28 @@ class MulGradXPUKernel : public framework::OpKernel<T> {
         (dy == NULL) ? reinterpret_cast<XPUType*>(NULL)
                     : reinterpret_cast<XPUType*>(
                           dy->mutable_data<T>(ctx.GetPlace()));
-    XpuFcInfo info_dx;
-    XpuFcInfo info_dy;
-    std::tuple<XpuFcInfo,
-               XpuFcInfo,
-               const XPUType*,
-               const XPUType*,
-               const XPUType*,
-               const XPUType*>
-        fc_info = MatmulGradFcInfo(xpu_ctx,
-                                   &RAII_GUARD,
-                                   info_forward,
-                                   false,
-                                   false,
-                                   x_ptr,
-                                   y_ptr,
-                                   dout_ptr);
+    phi::XpuFcInfo info_dx;
+    phi::XpuFcInfo info_dy;
+    std::tuple<phi::XpuFcInfo,
+               phi::XpuFcInfo,
+               const XPUType*,
+               const XPUType*,
+               const XPUType*,
+               const XPUType*>
+        fc_info = phi::MatmulGradFcInfo(xpu_ctx,
+                                        &RAII_GUARD,
+                                        info_forward,
+                                        false,
+                                        false,
+                                        x_ptr,
+                                        y_ptr,
+                                        dout_ptr);
     std::tie(info_dx, info_dy, a_1, b_1, a_2, b_2) = fc_info;
     if (dx) {
-      MatMulXPUFunction<XPUType>(xpu_ctx, a_1, b_1, c_1, info_dx, 1.0f);
+      phi::MatMulXPUFunction<XPUType>(xpu_ctx, a_1, b_1, c_1, info_dx, 1.0f);
     }
     if (dy) {
-      MatMulXPUFunction<XPUType>(xpu_ctx, a_2, b_2, c_2, info_dy, 1.0f);
+      phi::MatMulXPUFunction<XPUType>(xpu_ctx, a_2, b_2, c_2, info_dy, 1.0f);
     }
   }
 };
diff --git a/paddle/fluid/operators/xpu_api_wrapper.h b/paddle/fluid/operators/xpu_api_wrapper.h
index c85a765f3b6fd31e0f84e4da99b02fb03adf5914..c23fb1ae02ab4fb2d2bbdcbb6c3e79a5d615da9a 100644
--- a/paddle/fluid/operators/xpu_api_wrapper.h
+++ b/paddle/fluid/operators/xpu_api_wrapper.h
@@ -11,600 +11,13 @@ limitations under the License. */
 
 #pragma once
 #ifdef PADDLE_WITH_XPU
-#include <vector>
-#include "paddle/fluid/platform/device/device_wrapper.h"
-#include "paddle/phi/kernels/funcs/blas/blas.h"
+#include "paddle/phi/kernels/xpu/xpu_api_wrapper.h"
 
 namespace paddle {
 namespace operators {
 
 using float16 = typename XPUTypeTrait<paddle::platform::float16>::Type;
 
-enum XPUFCCalcType {
-  FC_INT16 = 0,
-  FC_INT32,
-  FC_FLOAT,
-};
-
-template <typename T>
-XPUFCCalcType FCCalcType() {
-  if (std::is_same<paddle::platform::float16, T>::value ||
-      std::is_same<float16, T>::value) {
-    return XPUFCCalcType::FC_INT16;
-  } else if (std::getenv("XPU_PADDLE_FC_INT32") != nullptr) {
-    return XPUFCCalcType::FC_INT32;
-  } else if (std::getenv("XPU_PADDLE_FC_LOCAL_INT16") != nullptr) {
-    return XPUFCCalcType::FC_FLOAT;
-  }
-  return XPUFCCalcType::FC_INT16;
-}
-
-struct XpuFcInfo {
-  int bs;
-  int m;
-  int n;
-  int k;
-  bool trans_x;
-  bool trans_y;
-  int stride_x;
-  int stride_y;
-  int stride_out;
-  float* max_x;
-  float* max_y;
-  float* max_out;
-  XpuFcInfo()
-      : bs(0),
-        m(0),
-        n(0),
-        k(0),
-        trans_x(false),
-        trans_y(false),
-        stride_x(0),
-        stride_y(0),
-        stride_out(0),
-        max_x(nullptr),
-        max_y(nullptr),
-        max_out(nullptr) {}
-  void InitFcInfo(int bs,
-                  int m,
-                  int n,
-                  int k,
-                  bool trans_x,
-                  bool trans_y,
-                  float* max_x,
-                  float* max_y,
-                  float* max_out) {
-    this->bs = bs;
-    this->m = m;
-    this->n = n;
-    this->k = k;
-    this->trans_x = trans_x;
-    this->trans_y = trans_y;
-    this->max_x = max_x;
-    this->max_y = max_y;
-    this->max_out = max_out;
-
-    if (this->bs <= 1) {
-      this->stride_x = trans_x ? m : k;
-      this->stride_y = trans_y ? k : n;
-      this->stride_out = n;
-    } else {
-      this->stride_x = m * k;
-      this->stride_y = k * n;
-      this->stride_out = m * n;
-    }
-  }
-};
-
-static std::ostream& operator<<(std::ostream& os, const XpuFcInfo& fc_inf) {
-  os << "fc_inf[ bs, m, n, k, trans_x, trans_y, stride_x, stride_y, "
-        "stride_out] = "
-     << "[" << fc_inf.bs << ", " << fc_inf.m << ", " << fc_inf.n << ", "
-     << fc_inf.k << ", " << fc_inf.trans_x << ", " << fc_inf.trans_y << ", "
-     << fc_inf.stride_x << ", " << fc_inf.stride_y << ", " << fc_inf.stride_out;
-  return os;
-}
-
-static void GetFCInfo(const phi::DDim& x_dims,
-                      const phi::DDim& y_dims,
-                      bool trans_x,
-                      bool trans_y,
-                      XpuFcInfo* info) {
-  framework::DDim new_x_dims =
-      (x_dims.size() > 1) ? x_dims : phi::make_ddim({1, x_dims[0]});
-  framework::DDim new_y_dims =
-      (y_dims.size() > 1) ? y_dims : phi::make_ddim({y_dims[0], 1});
-
-  auto mat_dim_a = phi::funcs::CreateMatrixDescriptor(new_x_dims, 0, trans_x);
-  auto mat_dim_b = phi::funcs::CreateMatrixDescriptor(new_y_dims, 0, trans_y);
-
-  if (x_dims.size() >= 3 && y_dims.size() <= 2) {
-    if (!trans_x) {
-      mat_dim_a.height_ *= mat_dim_a.batch_size_;
-      mat_dim_a.batch_size_ = 0;
-    } else {
-      mat_dim_b.batch_size_ = mat_dim_a.batch_size_;
-      mat_dim_b.height_ = mat_dim_b.height_ / mat_dim_b.batch_size_;
-    }
-  }
-
-  if (y_dims.size() >= 3 && x_dims.size() <= 2) {
-    PADDLE_ENFORCE_EQ(
-        mat_dim_b.trans_,
-        false,
-        platform::errors::InvalidArgument(
-            "xpu not support this Shape in matmul_op xdims = %s ydims = %s "
-            "x_trans = %d y_trans = %d",
-            x_dims.to_str(),
-            y_dims.to_str(),
-            mat_dim_a.trans_,
-            mat_dim_b.trans_));
-    mat_dim_b.height_ *= mat_dim_b.batch_size_;
-    mat_dim_b.batch_size_ = 0;
-  }
-
-  if (mat_dim_a.width_ == mat_dim_b.height_) {
-    if (mat_dim_a.batch_size_ == 0 && mat_dim_b.batch_size_ == 1) {
-      mat_dim_a.batch_size_ = mat_dim_b.batch_size_ = 0;
-    }
-    if (mat_dim_a.batch_size_ == 1 && mat_dim_b.batch_size_ == 0) {
-      mat_dim_a.batch_size_ = mat_dim_b.batch_size_ = 0;
-    }
-  }
-
-  PADDLE_ENFORCE_EQ(mat_dim_a.width_,
-                    mat_dim_b.height_,
-                    platform::errors::InvalidArgument(
-                        "Shape mistake in matmul_op xdims = %s ydims = %s "
-                        "x_trans = %d y_trans = %d",
-                        x_dims.to_str(),
-                        y_dims.to_str(),
-                        mat_dim_a.trans_,
-                        mat_dim_b.trans_));
-
-  info->m = mat_dim_a.height_;
-  info->n = mat_dim_b.width_;
-  info->k = mat_dim_a.width_;
-  info->bs = mat_dim_a.batch_size_;
-  info->trans_x = trans_x;
-  info->trans_y = trans_y;
-
-  if (info->bs <= 1) {
-    info->stride_x = trans_x ? info->m : info->k;
-    info->stride_y = trans_y ? info->k : info->n;
-    info->stride_out = info->n;
-  } else {
-    info->stride_x = info->m * info->k;
-    info->stride_y = info->k * info->n;
-    info->stride_out = info->m * info->n;
-  }
-}
-
-template <typename XPUType, typename FCT>
-static void xpu_fc_wrapper(xpu::Context* ctx,
-                           const XPUType* x,
-                           const XPUType* w,
-                           XPUType* y,
-                           int m,
-                           int n,
-                           int k,
-                           bool x_trans,
-                           bool w_trans,
-                           const float* x_maxptr,
-                           const float* w_maxptr,
-                           float* y_maxptr,
-                           int ldx,
-                           int ldw,
-                           int ldy,
-                           float alpha,
-                           float beta,
-                           const float* bias,
-                           const xpu::Activation_t& act) {
-  int r = 0;
-  if (x_trans && std::getenv("XPU_PADDLE_FC_TRANS_A") != nullptr &&
-      std::is_same<float, XPUType>::value) {
-    XPUType* l3_addr = nullptr;
-    xpu::ctx_guard RAII_GUARD(ctx);
-    l3_addr = RAII_GUARD.alloc_l3_or_gm<XPUType>(m * k);
-    PADDLE_ENFORCE_XDNN_NOT_NULL(l3_addr);
-
-    std::vector<int> shape = {k, m};
-    std::vector<int> axis = {1, 0};
-    r = xpu::transpose<XPUType>(ctx, x, l3_addr, shape, axis);
-    PADDLE_ENFORCE_XDNN_SUCCESS(r, "transpose");
-
-    r = xpu::fc_fusion<XPUType, XPUType, XPUType, FCT>(ctx,
-                                                       l3_addr,
-                                                       w,
-                                                       y,
-                                                       m,
-                                                       n,
-                                                       k,
-                                                       false,
-                                                       w_trans,
-                                                       x_maxptr,
-                                                       w_maxptr,
-                                                       y_maxptr,
-                                                       k,
-                                                       ldw,
-                                                       ldy,
-                                                       alpha,
-                                                       beta,
-                                                       bias,
-                                                       act);
-    PADDLE_ENFORCE_XDNN_SUCCESS(r, "fc_fusion");
-  } else {
-    r = xpu::fc_fusion<XPUType, XPUType, XPUType, FCT>(ctx,
-                                                       x,
-                                                       w,
-                                                       y,
-                                                       m,
-                                                       n,
-                                                       k,
-                                                       x_trans,
-                                                       w_trans,
-                                                       x_maxptr,
-                                                       w_maxptr,
-                                                       y_maxptr,
-                                                       ldx,
-                                                       ldw,
-                                                       ldy,
-                                                       alpha,
-                                                       beta,
-                                                       bias,
-                                                       act);
-    PADDLE_ENFORCE_XDNN_SUCCESS(r, "fc_fusion");
-  }
-}
-
-template <>
-void xpu_fc_wrapper<float16, int32_t>(xpu::Context* ctx,
-                                      const float16* x,
-                                      const float16* w,
-                                      float16* y,
-                                      int m,
-                                      int n,
-                                      int k,
-                                      bool x_trans,
-                                      bool w_trans,
-                                      const float* x_maxptr,
-                                      const float* w_maxptr,
-                                      float* y_maxptr,
-                                      int ldx,
-                                      int ldw,
-                                      int ldy,
-                                      float alpha,
-                                      float beta,
-                                      const float* bias,
-                                      const xpu::Activation_t& act) {
-  int r = xpu::Error_t::INVALID_PARAM;
-  PADDLE_ENFORCE_XDNN_SUCCESS(r, "xpu_fc_wrapper");
-}
-
-template <typename XPUType, typename FCT>
-static void xpu_fc_batch_wrapper(xpu::Context* xpu_ctx,
-                                 int bs,
-                                 bool trans_x,
-                                 bool trans_w,
-                                 int m,
-                                 int n,
-                                 int k,
-                                 float alpha,
-                                 const XPUType* x,
-                                 int stride_x,
-                                 const XPUType* w,
-                                 int stride_w,
-                                 float beta,
-                                 XPUType* y,
-                                 int stride_y,
-                                 const float* x_maxptr,
-                                 const float* w_maxptr) {
-  int r = xpu::fc_batched<XPUType, XPUType, XPUType, FCT>(
-      xpu_ctx,                              // Context* ctx,
-      bs,                                   // int batch_size,
-      trans_x,                              // bool x_trans,
-      trans_w,                              // bool w_trans,
-      m,                                    // int m,
-      n,                                    // int n,
-      k,                                    // int k,
-      alpha,                                // float alpha,
-      reinterpret_cast<const XPUType*>(x),  // const TX* x,
-      stride_x,                             // int stride_a,
-      reinterpret_cast<const XPUType*>(w),  // const TW* w,
-      stride_w,                             // int stride_b,
-      0.0,                                  // float beta,
-      reinterpret_cast<XPUType*>(y),        // TY* y,
-      stride_y,                             // int stride_c,
-      x_maxptr,                             // const float* x_maxptr,
-      w_maxptr);                            // const float* w_maxptr
-  PADDLE_ENFORCE_XDNN_SUCCESS(r, "fc_batched");
-}
-
-template <>
-void xpu_fc_batch_wrapper<float16, int32_t>(xpu::Context* xpu_ctx,
-                                            int bs,
-                                            bool trans_x,
-                                            bool trans_w,
-                                            int m,
-                                            int n,
-                                            int k,
-                                            float alpha,
-                                            const float16* x,
-                                            int stride_x,
-                                            const float16* w,
-                                            int stride_w,
-                                            float beta,
-                                            float16* y,
-                                            int stride_y,
-                                            const float* x_maxptr,
-                                            const float* w_maxptr) {
-  int r = xpu::Error_t::INVALID_PARAM;
-  PADDLE_ENFORCE_XDNN_SUCCESS(r, "xpu_fc_batch_wrapper");
-}
-
-template <>
-void xpu_fc_batch_wrapper<float16, float>(xpu::Context* xpu_ctx,
-                                          int bs,
-                                          bool trans_x,
-                                          bool trans_w,
-                                          int m,
-                                          int n,
-                                          int k,
-                                          float alpha,
-                                          const float16* x,
-                                          int stride_x,
-                                          const float16* w,
-                                          int stride_w,
-                                          float beta,
-                                          float16* y,
-                                          int stride_y,
-                                          const float* x_maxptr,
-                                          const float* w_maxptr) {
-  int r = xpu::Error_t::INVALID_PARAM;
-  PADDLE_ENFORCE_XDNN_SUCCESS(r, "xpu_fc_batch_wrapper");
-}
-
-template <typename T>
-static void MatMulXPUFunction(xpu::Context* xpu_ctx,
-                              const T* x,
-                              const T* y,
-                              T* out,
-                              const XpuFcInfo& fcinfo,
-                              float alpha) {
-  using XPUType = typename XPUTypeTrait<T>::Type;
-  using float16 = typename XPUTypeTrait<paddle::platform::float16>::Type;
-  int fccal_type = FCCalcType<XPUType>();
-
-  decltype(&paddle::operators::xpu_fc_wrapper<XPUType, int16_t>)
-      fc_api_list[3] = {
-          &paddle::operators::xpu_fc_wrapper<XPUType, int16_t>,
-          &paddle::operators::xpu_fc_wrapper<XPUType, int32_t>,
-          &paddle::operators::xpu_fc_wrapper<XPUType, float>,
-      };
-  decltype(&paddle::operators::xpu_fc_batch_wrapper<XPUType, int16_t>)
-      fc_batch_api_list[3] = {
-          &paddle::operators::xpu_fc_batch_wrapper<XPUType, int16_t>,
-          &paddle::operators::xpu_fc_batch_wrapper<XPUType, int32_t>,
-          &paddle::operators::xpu_fc_batch_wrapper<XPUType, float>,
-      };
-
-  auto fc_api = fc_api_list[fccal_type];
-  auto fc_batch_api = fc_batch_api_list[fccal_type];
-
-  int m = fcinfo.m;
-  int n = fcinfo.n;
-  int k = fcinfo.k;
-  int batch_size = fcinfo.bs;
-  int ldx = fcinfo.stride_x;
-  int ldy = fcinfo.stride_y;
-  int ldout = fcinfo.stride_out;
-  bool trans_x = fcinfo.trans_x;
-  bool trans_y = fcinfo.trans_y;
-  float* max_x = fcinfo.max_x;
-  float* max_y = fcinfo.max_y;
-  float* max_out = fcinfo.max_out;
-
-  if (batch_size <= 1) {
-    fc_api(xpu_ctx,
-           reinterpret_cast<const XPUType*>(x),
-           reinterpret_cast<const XPUType*>(y),
-           reinterpret_cast<XPUType*>(out),
-           m,
-           n,
-           k,
-           trans_x,
-           trans_y,
-           max_x,
-           max_y,
-           max_out,
-           ldx,
-           ldy,
-           ldout,
-           alpha,
-           0,
-           nullptr,
-           xpu::Activation_t::LINEAR);
-  } else {
-    // batch matmul
-    fc_batch_api(xpu_ctx,     // Context* ctx,
-                 batch_size,  // int batch_size,
-                 trans_x,     // bool x_trans,
-                 trans_y,     // bool w_trans,
-                 m,           // int m,
-                 n,           // int n,
-                 k,           // int k,
-                 alpha,       // float alpha,
-                 reinterpret_cast<const XPUType*>(x),  // const TX* x,
-                 ldx,                                  // int stride_a,
-                 reinterpret_cast<const XPUType*>(y),  // const TW* w,
-                 ldy,                                  // int stride_b,
-                 0.0,                                  // float beta,
-                 reinterpret_cast<XPUType*>(out),      // TY* y,
-                 ldout,                                // int stride_c,
-                 max_x,                                // const float* x_maxptr,
-                 max_y);                               // const float* w_maxptr
-  }
-}
-
-template <typename T>
-static std::tuple<XpuFcInfo, XpuFcInfo, const T*, const T*, const T*, const T*>
-MatmulGradFcInfo(xpu::Context* xpu_ctx,
-                 xpu::ctx_guard* RAII_GUARD,
-                 const XpuFcInfo& dout_shape,
-                 bool trans_x,
-                 bool trans_y,
-                 const T* x,
-                 const T* y,
-                 const T* dout) {
-  XpuFcInfo dx_shape, dy_shape;
-  const T* dx_a = NULL;
-  const T* dx_b = NULL;
-  const T* dy_a = NULL;
-  const T* dy_b = NULL;
-  bool copy_to_l3 = false;
-  float* max_dout = NULL;
-  int maxptr_size = xpu_ctx->max_ptr_size();
-  uint64_t l3_size = uint64_t(xpu_ctx->_l3_mgr.get_size());
-  int bs = (dout_shape.bs <= 1) ? (1) : (dout_shape.bs);
-  int dx_size = bs * dout_shape.m * dout_shape.k;
-  int dy_size = bs * dout_shape.k * dout_shape.n;
-  int dout_size = bs * dout_shape.m * dout_shape.n;
-  if (trans_x && trans_y) {
-    copy_to_l3 = l3_size >= (dout_size * 2 + dy_size) * sizeof(T);
-  } else if (trans_x) {
-    copy_to_l3 = l3_size >= dout_size * sizeof(T);
-  } else if (trans_y) {
-    copy_to_l3 = l3_size >= dout_size * 2 * sizeof(T);
-  } else {
-    copy_to_l3 = l3_size >= (dout_size + dx_size) * sizeof(T);
-  }
-
-  const T* dout_new = dout;
-  int r = 0;
-  if (copy_to_l3) {
-    T* dout_l3 = RAII_GUARD->alloc_l3<T>(dout_size);
-    PADDLE_ENFORCE_XDNN_NOT_NULL(dout_l3);
-    if ((dout_shape.bs > 1) || ((dout_shape.bs <= 1) &&
-                                (FCCalcType<T>() == XPUFCCalcType::FC_FLOAT))) {
-      r = xpu::copy(xpu_ctx, dout, dout_l3, dout_size);
-      PADDLE_ENFORCE_XDNN_SUCCESS(r, "copy");
-      dout_new = dout_l3;
-    } else {
-      max_dout = RAII_GUARD->alloc_l3_or_gm<float>(maxptr_size);
-      PADDLE_ENFORCE_XDNN_NOT_NULL(max_dout);
-
-      r = xpu::findmax_copy_fusion(xpu_ctx, dout, max_dout, dout_l3, dout_size);
-      PADDLE_ENFORCE_XDNN_SUCCESS(r, "findmax_copy_fusion");
-      dout_new = dout_l3;
-    }
-  } else if (((dout_shape.bs <= 1) &&
-              (FCCalcType<T>() != XPUFCCalcType::FC_FLOAT))) {
-    max_dout = RAII_GUARD->alloc_l3_or_gm<float>(maxptr_size);
-    PADDLE_ENFORCE_XDNN_NOT_NULL(max_dout);
-    r = xpu::findmax_copy_fusion(
-        xpu_ctx, dout, max_dout, reinterpret_cast<T*>(NULL), dout_size);
-    PADDLE_ENFORCE_XDNN_SUCCESS(r, "findmax_copy_fusion");
-  }
-
-  if (trans_x && trans_y) {
-    // dx = T(y) * T(dout)
-    dx_shape.InitFcInfo(dout_shape.bs,
-                        dout_shape.k,
-                        dout_shape.m,
-                        dout_shape.n,
-                        true,
-                        true,
-                        nullptr,
-                        max_dout,
-                        nullptr);
-    dx_a = y, dx_b = dout_new;
-    // dy = T(dout) * T(x)
-    dy_shape.InitFcInfo(dout_shape.bs,
-                        dout_shape.n,
-                        dout_shape.k,
-                        dout_shape.m,
-                        true,
-                        true,
-                        max_dout,
-                        nullptr,
-                        nullptr);
-    dy_a = dout_new, dy_b = x;
-  } else if (trans_x) {
-    // dx = y * T(dout)
-    dx_shape.InitFcInfo(dout_shape.bs,
-                        dout_shape.k,
-                        dout_shape.m,
-                        dout_shape.n,
-                        false,
-                        true,
-                        nullptr,
-                        max_dout,
-                        nullptr);
-    dx_a = y, dx_b = dout_new;
-    // dy = x * dout
-    dy_shape.InitFcInfo(dout_shape.bs,
-                        dout_shape.k,
-                        dout_shape.n,
-                        dout_shape.m,
-                        false,
-                        false,
-                        nullptr,
-                        max_dout,
-                        nullptr);
-    dy_a = x, dy_b = dout_new;
-  } else if (trans_y) {
-    // dx = dout * y
-    dx_shape.InitFcInfo(dout_shape.bs,
-                        dout_shape.m,
-                        dout_shape.k,
-                        dout_shape.n,
-                        false,
-                        false,
-                        max_dout,
-                        nullptr,
-                        nullptr);
-    dx_a = dout_new, dx_b = y;
-    // dy = T(dout) * x
-    dy_shape.InitFcInfo(dout_shape.bs,
-                        dout_shape.n,
-                        dout_shape.k,
-                        dout_shape.m,
-                        true,
-                        false,
-                        max_dout,
-                        nullptr,
-                        nullptr);
-    dy_a = dout_new, dy_b = x;
-  } else {
-    // dx = dout * T(y)
-    dx_shape.InitFcInfo(dout_shape.bs,
-                        dout_shape.m,
-                        dout_shape.k,
-                        dout_shape.n,
-                        false,
-                        true,
-                        max_dout,
-                        nullptr,
-                        nullptr);
-    dx_a = dout_new, dx_b = y;
-    // dy = T(x) * dout
-    dy_shape.InitFcInfo(dout_shape.bs,
-                        dout_shape.k,
-                        dout_shape.n,
-                        dout_shape.m,
-                        true,
-                        false,
-                        nullptr,
-                        max_dout,
-                        nullptr);
-    dy_a = x, dy_b = dout_new;
-  }
-  std::tuple<XpuFcInfo, XpuFcInfo, const T*, const T*, const T*, const T*>
-      result = std::make_tuple(dx_shape, dy_shape, dx_a, dx_b, dy_a, dy_b);
-
-  return result;
-}
-
 }  // namespace operators
 }  // namespace paddle
 #endif
diff --git a/paddle/phi/kernels/xpu/matmul_grad_kernel.cc b/paddle/phi/kernels/xpu/matmul_grad_kernel.cc
new file mode 100644
index 0000000000000000000000000000000000000000..c2dd5788701495ed55ce0582b35d49fb269b82fa
--- /dev/null
+++ b/paddle/phi/kernels/xpu/matmul_grad_kernel.cc
@@ -0,0 +1,91 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/phi/kernels/matmul_grad_kernel.h"
+#include "paddle/phi/backends/xpu/enforce_xpu.h"
+#include "paddle/phi/backends/xpu/xpu_context.h"
+#include "paddle/phi/core/kernel_registry.h"
+#include "paddle/phi/kernels/xpu/xpu_api_wrapper.h"
+
+namespace phi {
+
+template <typename T, typename Context>
+void MatmulGradKernel(const Context& dev_ctx,
+                      const DenseTensor& x,
+                      const DenseTensor& y,
+                      const DenseTensor& dout,
+                      bool transpose_x,
+                      bool transpose_y,
+                      DenseTensor* dx,
+                      DenseTensor* dy) {
+  using XPUType = typename XPUTypeTrait<T>::Type;
+
+  if (dx) {
+    dev_ctx.template Alloc<T>(dx);
+  }
+  if (dy) {
+    dev_ctx.template Alloc<T>(dy);
+  }
+
+  const XPUType* dout_ptr = reinterpret_cast<const XPUType*>(dout.data<T>());
+  const XPUType* x_ptr = reinterpret_cast<const XPUType*>(x.data<T>());
+  const XPUType* y_ptr = reinterpret_cast<const XPUType*>(y.data<T>());
+
+  xpu::Context* xpu_ctx = dev_ctx.x_context();
+
+  XpuFcInfo info_forward;
+  GetFCInfo(x.dims(), y.dims(), transpose_x, transpose_y, &info_forward);
+  xpu::ctx_guard RAII_GUARD(xpu_ctx);
+  // begin calculate
+  const XPUType* a_1 = reinterpret_cast<const XPUType*>(NULL);
+  const XPUType* b_1 = reinterpret_cast<const XPUType*>(NULL);
+  const XPUType* a_2 = reinterpret_cast<const XPUType*>(NULL);
+  const XPUType* b_2 = reinterpret_cast<const XPUType*>(NULL);
+  XPUType* c_1 = (dx == NULL) ? reinterpret_cast<XPUType*>(NULL)
+                              : reinterpret_cast<XPUType*>(dx->data<T>());
+  XPUType* c_2 = (dy == NULL) ? reinterpret_cast<XPUType*>(NULL)
+                              : reinterpret_cast<XPUType*>(dy->data<T>());
+  XpuFcInfo info_dx;
+  XpuFcInfo info_dy;
+  std::tuple<XpuFcInfo,
+             XpuFcInfo,
+             const XPUType*,
+             const XPUType*,
+             const XPUType*,
+             const XPUType*>
+      fc_info = MatmulGradFcInfo(xpu_ctx,
+                                 &RAII_GUARD,
+                                 info_forward,
+                                 transpose_x,
+                                 transpose_y,
+                                 x_ptr,
+                                 y_ptr,
+                                 dout_ptr);
+  std::tie(info_dx, info_dy, a_1, b_1, a_2, b_2) = fc_info;
+  if (dx) {
+    MatMulXPUFunction<XPUType>(xpu_ctx, a_1, b_1, c_1, info_dx, 1.0f);
+  }
+  if (dy) {
+    MatMulXPUFunction<XPUType>(xpu_ctx, a_2, b_2, c_2, info_dy, 1.0f);
+  }
+}
+
+}  // namespace phi
+
+PD_REGISTER_KERNEL(matmul_grad,
+                   XPU,
+                   ALL_LAYOUT,
+                   phi::MatmulGradKernel,
+                   float,
+                   phi::dtype::float16) {}
diff --git a/paddle/phi/kernels/xpu/matmul_kernel.cc b/paddle/phi/kernels/xpu/matmul_kernel.cc
new file mode 100644
index 0000000000000000000000000000000000000000..de20aef317bd3ce3ea49a2c954a4f6c5683bf5b5
--- /dev/null
+++ b/paddle/phi/kernels/xpu/matmul_kernel.cc
@@ -0,0 +1,48 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/phi/kernels/matmul_kernel.h"
+#include "paddle/phi/backends/xpu/enforce_xpu.h"
+#include "paddle/phi/backends/xpu/xpu_context.h"
+#include "paddle/phi/core/kernel_registry.h"
+#include "paddle/phi/kernels/xpu/xpu_api_wrapper.h"
+
+namespace phi {
+
+template <typename T, typename Context>
+void MatmulKernel(const Context& dev_ctx,
+                  const DenseTensor& x,
+                  const DenseTensor& y,
+                  bool transpose_x,
+                  bool transpose_y,
+                  DenseTensor* out) {
+  using XPUType = typename XPUTypeTrait<T>::Type;
+
+  dev_ctx.template Alloc<T>(out);
+  const XPUType* x_ptr = reinterpret_cast<const XPUType*>(x.data<T>());
+  const XPUType* y_ptr = reinterpret_cast<const XPUType*>(y.data<T>());
+  XPUType* out_ptr = reinterpret_cast<XPUType*>(out->data<T>());
+  auto x_dims = x.dims();
+  auto y_dims = y.dims();
+
+  XpuFcInfo fc_info;
+  GetFCInfo(x_dims, y_dims, transpose_x, transpose_y, &fc_info);
+  xpu::Context* xpu_ctx = dev_ctx.x_context();
+  MatMulXPUFunction<XPUType>(xpu_ctx, x_ptr, y_ptr, out_ptr, fc_info, 1.0f);
+}
+
+}  // namespace phi
+
+PD_REGISTER_KERNEL(
+    matmul, XPU, ALL_LAYOUT, phi::MatmulKernel, float, phi::dtype::float16) {}
diff --git a/paddle/phi/kernels/xpu/xpu_api_wrapper.h b/paddle/phi/kernels/xpu/xpu_api_wrapper.h
new file mode 100644
index 0000000000000000000000000000000000000000..0ebed7f449e3f26d814a96203c6bd5a0bebc84c0
--- /dev/null
+++ b/paddle/phi/kernels/xpu/xpu_api_wrapper.h
@@ -0,0 +1,612 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#ifdef PADDLE_WITH_XPU
+
+#include <vector>
+#include "paddle/phi/backends/xpu/enforce_xpu.h"
+#include "paddle/phi/backends/xpu/xpu_header.h"
+#include "paddle/phi/backends/xpu/xpu_info.h"
+#include "paddle/phi/kernels/funcs/blas/blas.h"
+
+namespace phi {
+
+using float16 = typename XPUTypeTrait<phi::dtype::float16>::Type;
+
+enum XPUFCCalcType {
+  FC_INT16 = 0,
+  FC_INT32,
+  FC_FLOAT,
+};
+
+template <typename T>
+XPUFCCalcType FCCalcType() {
+  if (std::is_same<phi::dtype::float16, T>::value ||
+      std::is_same<float16, T>::value) {
+    return XPUFCCalcType::FC_INT16;
+  } else if (std::getenv("XPU_PADDLE_FC_INT32") != nullptr) {
+    return XPUFCCalcType::FC_INT32;
+  } else if (std::getenv("XPU_PADDLE_FC_LOCAL_INT16") != nullptr) {
+    return XPUFCCalcType::FC_FLOAT;
+  }
+  return XPUFCCalcType::FC_INT16;
+}
+
+struct XpuFcInfo {
+  int bs;
+  int m;
+  int n;
+  int k;
+  bool trans_x;
+  bool trans_y;
+  int stride_x;
+  int stride_y;
+  int stride_out;
+  float* max_x;
+  float* max_y;
+  float* max_out;
+  XpuFcInfo()
+      : bs(0),
+        m(0),
+        n(0),
+        k(0),
+        trans_x(false),
+        trans_y(false),
+        stride_x(0),
+        stride_y(0),
+        stride_out(0),
+        max_x(nullptr),
+        max_y(nullptr),
+        max_out(nullptr) {}
+  void InitFcInfo(int bs,
+                  int m,
+                  int n,
+                  int k,
+                  bool trans_x,
+                  bool trans_y,
+                  float* max_x,
+                  float* max_y,
+                  float* max_out) {
+    this->bs = bs;
+    this->m = m;
+    this->n = n;
+    this->k = k;
+    this->trans_x = trans_x;
+    this->trans_y = trans_y;
+    this->max_x = max_x;
+    this->max_y = max_y;
+    this->max_out = max_out;
+
+    if (this->bs <= 1) {
+      this->stride_x = trans_x ? m : k;
+      this->stride_y = trans_y ? k : n;
+      this->stride_out = n;
+    } else {
+      this->stride_x = m * k;
+      this->stride_y = k * n;
+      this->stride_out = m * n;
+    }
+  }
+};
+
+static std::ostream& operator<<(std::ostream& os, const XpuFcInfo& fc_inf) {
+  os << "fc_inf[ bs, m, n, k, trans_x, trans_y, stride_x, stride_y, "
+        "stride_out] = "
+     << "[" << fc_inf.bs << ", " << fc_inf.m << ", " << fc_inf.n << ", "
+     << fc_inf.k << ", " << fc_inf.trans_x << ", " << fc_inf.trans_y << ", "
+     << fc_inf.stride_x << ", " << fc_inf.stride_y << ", " << fc_inf.stride_out;
+  return os;
+}
+
+static void GetFCInfo(const phi::DDim& x_dims,
+                      const phi::DDim& y_dims,
+                      bool trans_x,
+                      bool trans_y,
+                      XpuFcInfo* info) {
+  DDim new_x_dims =
+      (x_dims.size() > 1) ? x_dims : phi::make_ddim({1, x_dims[0]});
+  DDim new_y_dims =
+      (y_dims.size() > 1) ? y_dims : phi::make_ddim({y_dims[0], 1});
+
+  auto mat_dim_a = phi::funcs::CreateMatrixDescriptor(new_x_dims, 0, trans_x);
+  auto mat_dim_b = phi::funcs::CreateMatrixDescriptor(new_y_dims, 0, trans_y);
+
+  if (x_dims.size() >= 3 && y_dims.size() <= 2) {
+    if (!trans_x) {
+      mat_dim_a.height_ *= mat_dim_a.batch_size_;
+      mat_dim_a.batch_size_ = 0;
+    } else {
+      mat_dim_b.batch_size_ = mat_dim_a.batch_size_;
+      mat_dim_b.height_ = mat_dim_b.height_ / mat_dim_b.batch_size_;
+    }
+  }
+
+  if (y_dims.size() >= 3 && x_dims.size() <= 2) {
+    PADDLE_ENFORCE_EQ(
+        mat_dim_b.trans_,
+        false,
+        phi::errors::InvalidArgument(
+            "xpu not support this Shape in matmul_op xdims = %s ydims = %s "
+            "x_trans = %d y_trans = %d",
+            x_dims.to_str(),
+            y_dims.to_str(),
+            mat_dim_a.trans_,
+            mat_dim_b.trans_));
+    mat_dim_b.height_ *= mat_dim_b.batch_size_;
+    mat_dim_b.batch_size_ = 0;
+  }
+
+  if (mat_dim_a.width_ == mat_dim_b.height_) {
+    if (mat_dim_a.batch_size_ == 0 && mat_dim_b.batch_size_ == 1) {
+      mat_dim_a.batch_size_ = mat_dim_b.batch_size_ = 0;
+    }
+    if (mat_dim_a.batch_size_ == 1 && mat_dim_b.batch_size_ == 0) {
+      mat_dim_a.batch_size_ = mat_dim_b.batch_size_ = 0;
+    }
+  }
+
+  PADDLE_ENFORCE_EQ(mat_dim_a.width_,
+                    mat_dim_b.height_,
+                    phi::errors::InvalidArgument(
+                        "Shape mistake in matmul_op xdims = %s ydims = %s "
+                        "x_trans = %d y_trans = %d",
+                        x_dims.to_str(),
+                        y_dims.to_str(),
+                        mat_dim_a.trans_,
+                        mat_dim_b.trans_));
+
+  info->m = mat_dim_a.height_;
+  info->n = mat_dim_b.width_;
+  info->k = mat_dim_a.width_;
+  info->bs = mat_dim_a.batch_size_;
+  info->trans_x = trans_x;
+  info->trans_y = trans_y;
+
+  if (info->bs <= 1) {
+    info->stride_x = trans_x ? info->m : info->k;
+    info->stride_y = trans_y ? info->k : info->n;
+    info->stride_out = info->n;
+  } else {
+    info->stride_x = info->m * info->k;
+    info->stride_y = info->k * info->n;
+    info->stride_out = info->m * info->n;
+  }
+}
+
+template <typename XPUType, typename FCT>
+static void xpu_fc_wrapper(xpu::Context* ctx,
+                           const XPUType* x,
+                           const XPUType* w,
+                           XPUType* y,
+                           int m,
+                           int n,
+                           int k,
+                           bool x_trans,
+                           bool w_trans,
+                           const float* x_maxptr,
+                           const float* w_maxptr,
+                           float* y_maxptr,
+                           int ldx,
+                           int ldw,
+                           int ldy,
+                           float alpha,
+                           float beta,
+                           const float* bias,
+                           const xpu::Activation_t& act) {
+  int r = 0;
+  if (x_trans && std::getenv("XPU_PADDLE_FC_TRANS_A") != nullptr &&
+      std::is_same<float, XPUType>::value) {
+    XPUType* l3_addr = nullptr;
+    xpu::ctx_guard RAII_GUARD(ctx);
+    l3_addr = RAII_GUARD.alloc_l3_or_gm<XPUType>(m * k);
+    PADDLE_ENFORCE_XDNN_NOT_NULL(l3_addr);
+
+    std::vector<int> shape = {k, m};
+    std::vector<int> axis = {1, 0};
+    r = xpu::transpose<XPUType>(ctx, x, l3_addr, shape, axis);
+    PADDLE_ENFORCE_XDNN_SUCCESS(r, "transpose");
+
+    r = xpu::fc_fusion<XPUType, XPUType, XPUType, FCT>(ctx,
+                                                       l3_addr,
+                                                       w,
+                                                       y,
+                                                       m,
+                                                       n,
+                                                       k,
+                                                       false,
+                                                       w_trans,
+                                                       x_maxptr,
+                                                       w_maxptr,
+                                                       y_maxptr,
+                                                       k,
+                                                       ldw,
+                                                       ldy,
+                                                       alpha,
+                                                       beta,
+                                                       bias,
+                                                       act);
+    PADDLE_ENFORCE_XDNN_SUCCESS(r, "fc_fusion");
+  } else {
+    r = xpu::fc_fusion<XPUType, XPUType, XPUType, FCT>(ctx,
+                                                       x,
+                                                       w,
+                                                       y,
+                                                       m,
+                                                       n,
+                                                       k,
+                                                       x_trans,
+                                                       w_trans,
+                                                       x_maxptr,
+                                                       w_maxptr,
+                                                       y_maxptr,
+                                                       ldx,
+                                                       ldw,
+                                                       ldy,
+                                                       alpha,
+                                                       beta,
+                                                       bias,
+                                                       act);
+    PADDLE_ENFORCE_XDNN_SUCCESS(r, "fc_fusion");
+  }
+}
+
+template <>
+void xpu_fc_wrapper<float16, int32_t>(xpu::Context* ctx,
+                                      const float16* x,
+                                      const float16* w,
+                                      float16* y,
+                                      int m,
+                                      int n,
+                                      int k,
+                                      bool x_trans,
+                                      bool w_trans,
+                                      const float* x_maxptr,
+                                      const float* w_maxptr,
+                                      float* y_maxptr,
+                                      int ldx,
+                                      int ldw,
+                                      int ldy,
+                                      float alpha,
+                                      float beta,
+                                      const float* bias,
+                                      const xpu::Activation_t& act) {
+  int r = xpu::Error_t::INVALID_PARAM;
"xpu_fc_wrapper"); +} + +template +static void xpu_fc_batch_wrapper(xpu::Context* xpu_ctx, + int bs, + bool trans_x, + bool trans_w, + int m, + int n, + int k, + float alpha, + const XPUType* x, + int stride_x, + const XPUType* w, + int stride_w, + float beta, + XPUType* y, + int stride_y, + const float* x_maxptr, + const float* w_maxptr) { + int r = xpu::fc_batched( + xpu_ctx, // Context* ctx, + bs, // int batch_size, + trans_x, // bool x_trans, + trans_w, // bool w_trans, + m, // int m, + n, // int n, + k, // int k, + alpha, // float alpha, + reinterpret_cast(x), // const TX* x, + stride_x, // int stride_a, + reinterpret_cast(w), // const TW* w, + stride_w, // int stride_b, + 0.0, // float beta, + reinterpret_cast(y), // TY* y, + stride_y, // int stride_c, + x_maxptr, // const float* x_maxptr, + w_maxptr); // const float* w_maxptr + PADDLE_ENFORCE_XDNN_SUCCESS(r, "fc_batched"); +} + +template <> +void xpu_fc_batch_wrapper(xpu::Context* xpu_ctx, + int bs, + bool trans_x, + bool trans_w, + int m, + int n, + int k, + float alpha, + const float16* x, + int stride_x, + const float16* w, + int stride_w, + float beta, + float16* y, + int stride_y, + const float* x_maxptr, + const float* w_maxptr) { + int r = xpu::Error_t::INVALID_PARAM; + PADDLE_ENFORCE_XDNN_SUCCESS(r, "xpu_fc_batch_wrapper"); +} + +template <> +void xpu_fc_batch_wrapper(xpu::Context* xpu_ctx, + int bs, + bool trans_x, + bool trans_w, + int m, + int n, + int k, + float alpha, + const float16* x, + int stride_x, + const float16* w, + int stride_w, + float beta, + float16* y, + int stride_y, + const float* x_maxptr, + const float* w_maxptr) { + int r = xpu::Error_t::INVALID_PARAM; + PADDLE_ENFORCE_XDNN_SUCCESS(r, "xpu_fc_batch_wrapper"); +} + +template +static void MatMulXPUFunction(xpu::Context* xpu_ctx, + const T* x, + const T* y, + T* out, + const XpuFcInfo& fcinfo, + float alpha) { + using XPUType = typename XPUTypeTrait::Type; + int fccal_type = FCCalcType(); + + decltype(&xpu_fc_wrapper) fc_api_list[3] = { + &xpu_fc_wrapper, + &xpu_fc_wrapper, + &xpu_fc_wrapper, + }; + decltype(&xpu_fc_batch_wrapper) fc_batch_api_list[3] = { + &xpu_fc_batch_wrapper, + &xpu_fc_batch_wrapper, + &xpu_fc_batch_wrapper, + }; + + auto fc_api = fc_api_list[fccal_type]; + auto fc_batch_api = fc_batch_api_list[fccal_type]; + + int m = fcinfo.m; + int n = fcinfo.n; + int k = fcinfo.k; + int batch_size = fcinfo.bs; + int ldx = fcinfo.stride_x; + int ldy = fcinfo.stride_y; + int ldout = fcinfo.stride_out; + bool trans_x = fcinfo.trans_x; + bool trans_y = fcinfo.trans_y; + float* max_x = fcinfo.max_x; + float* max_y = fcinfo.max_y; + float* max_out = fcinfo.max_out; + + if (batch_size <= 1) { + fc_api(xpu_ctx, + reinterpret_cast(x), + reinterpret_cast(y), + reinterpret_cast(out), + m, + n, + k, + trans_x, + trans_y, + max_x, + max_y, + max_out, + ldx, + ldy, + ldout, + alpha, + 0, + nullptr, + xpu::Activation_t::LINEAR); + } else { + // batch matmul + fc_batch_api(xpu_ctx, // Context* ctx, + batch_size, // int batch_size, + trans_x, // bool x_trans, + trans_y, // bool w_trans, + m, // int m, + n, // int n, + k, // int k, + alpha, // float alpha, + reinterpret_cast(x), // const TX* x, + ldx, // int stride_a, + reinterpret_cast(y), // const TW* w, + ldy, // int stride_b, + 0.0, // float beta, + reinterpret_cast(out), // TY* y, + ldout, // int stride_c, + max_x, // const float* x_maxptr, + max_y); // const float* w_maxptr + } +} + +template +static std::tuple +MatmulGradFcInfo(xpu::Context* xpu_ctx, + xpu::ctx_guard* RAII_GUARD, + const XpuFcInfo& 
+                 const XpuFcInfo& dout_shape,
+                 bool trans_x,
+                 bool trans_y,
+                 const T* x,
+                 const T* y,
+                 const T* dout) {
+  XpuFcInfo dx_shape, dy_shape;
+  const T* dx_a = NULL;
+  const T* dx_b = NULL;
+  const T* dy_a = NULL;
+  const T* dy_b = NULL;
+  bool copy_to_l3 = false;
+  float* max_dout = NULL;
+  int maxptr_size = xpu_ctx->max_ptr_size();
+  uint64_t l3_size = uint64_t(xpu_ctx->_l3_mgr.get_size());
+  int bs = (dout_shape.bs <= 1) ? (1) : (dout_shape.bs);
+  int dx_size = bs * dout_shape.m * dout_shape.k;
+  int dy_size = bs * dout_shape.k * dout_shape.n;
+  int dout_size = bs * dout_shape.m * dout_shape.n;
+  if (trans_x && trans_y) {
+    copy_to_l3 = l3_size >= (dout_size * 2 + dy_size) * sizeof(T);
+  } else if (trans_x) {
+    copy_to_l3 = l3_size >= dout_size * sizeof(T);
+  } else if (trans_y) {
+    copy_to_l3 = l3_size >= dout_size * 2 * sizeof(T);
+  } else {
+    copy_to_l3 = l3_size >= (dout_size + dx_size) * sizeof(T);
+  }
+
+  const T* dout_new = dout;
+  int r = 0;
+  if (copy_to_l3) {
+    T* dout_l3 = RAII_GUARD->alloc_l3<T>(dout_size);
+    PADDLE_ENFORCE_XDNN_NOT_NULL(dout_l3);
+    if ((dout_shape.bs > 1) || ((dout_shape.bs <= 1) &&
+                                (FCCalcType<T>() == XPUFCCalcType::FC_FLOAT))) {
+      r = xpu::copy(xpu_ctx, dout, dout_l3, dout_size);
+      PADDLE_ENFORCE_XDNN_SUCCESS(r, "copy");
+      dout_new = dout_l3;
+    } else {
+      max_dout = RAII_GUARD->alloc_l3_or_gm<float>(maxptr_size);
+      PADDLE_ENFORCE_XDNN_NOT_NULL(max_dout);
+
+      r = xpu::findmax_copy_fusion(xpu_ctx, dout, max_dout, dout_l3, dout_size);
+      PADDLE_ENFORCE_XDNN_SUCCESS(r, "findmax_copy_fusion");
+      dout_new = dout_l3;
+    }
+  } else if (((dout_shape.bs <= 1) &&
+              (FCCalcType<T>() != XPUFCCalcType::FC_FLOAT))) {
+    max_dout = RAII_GUARD->alloc_l3_or_gm<float>(maxptr_size);
+    PADDLE_ENFORCE_XDNN_NOT_NULL(max_dout);
+    r = xpu::findmax_copy_fusion(
+        xpu_ctx, dout, max_dout, reinterpret_cast<T*>(NULL), dout_size);
+    PADDLE_ENFORCE_XDNN_SUCCESS(r, "findmax_copy_fusion");
+  }
+
+  if (trans_x && trans_y) {
+    // dx = T(y) * T(dout)
+    dx_shape.InitFcInfo(dout_shape.bs,
+                        dout_shape.k,
+                        dout_shape.m,
+                        dout_shape.n,
+                        true,
+                        true,
+                        nullptr,
+                        max_dout,
+                        nullptr);
+    dx_a = y, dx_b = dout_new;
+    // dy = T(dout) * T(x)
+    dy_shape.InitFcInfo(dout_shape.bs,
+                        dout_shape.n,
+                        dout_shape.k,
+                        dout_shape.m,
+                        true,
+                        true,
+                        max_dout,
+                        nullptr,
+                        nullptr);
+    dy_a = dout_new, dy_b = x;
+  } else if (trans_x) {
+    // dx = y * T(dout)
+    dx_shape.InitFcInfo(dout_shape.bs,
+                        dout_shape.k,
+                        dout_shape.m,
+                        dout_shape.n,
+                        false,
+                        true,
+                        nullptr,
+                        max_dout,
+                        nullptr);
+    dx_a = y, dx_b = dout_new;
+    // dy = x * dout
+    dy_shape.InitFcInfo(dout_shape.bs,
+                        dout_shape.k,
+                        dout_shape.n,
+                        dout_shape.m,
+                        false,
+                        false,
+                        nullptr,
+                        max_dout,
+                        nullptr);
+    dy_a = x, dy_b = dout_new;
+  } else if (trans_y) {
+    // dx = dout * y
+    dx_shape.InitFcInfo(dout_shape.bs,
+                        dout_shape.m,
+                        dout_shape.k,
+                        dout_shape.n,
+                        false,
+                        false,
+                        max_dout,
+                        nullptr,
+                        nullptr);
+    dx_a = dout_new, dx_b = y;
+    // dy = T(dout) * x
+    dy_shape.InitFcInfo(dout_shape.bs,
+                        dout_shape.n,
+                        dout_shape.k,
+                        dout_shape.m,
+                        true,
+                        false,
+                        max_dout,
+                        nullptr,
+                        nullptr);
+    dy_a = dout_new, dy_b = x;
+  } else {
+    // dx = dout * T(y)
+    dx_shape.InitFcInfo(dout_shape.bs,
+                        dout_shape.m,
+                        dout_shape.k,
+                        dout_shape.n,
+                        false,
+                        true,
+                        max_dout,
+                        nullptr,
+                        nullptr);
+    dx_a = dout_new, dx_b = y;
+    // dy = T(x) * dout
+    dy_shape.InitFcInfo(dout_shape.bs,
+                        dout_shape.k,
+                        dout_shape.n,
+                        dout_shape.m,
+                        true,
+                        false,
+                        nullptr,
+                        max_dout,
+                        nullptr);
+    dy_a = x, dy_b = dout_new;
+  }
+  std::tuple<XpuFcInfo, XpuFcInfo, const T*, const T*, const T*, const T*>
+      result = std::make_tuple(dx_shape, dy_shape, dx_a, dx_b, dy_a, dy_b);
+
+  return result;
+}
+
+}  // namespace phi
+#endif
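
Reviewer note (illustrative only, not part of the patch): the relocated helpers keep the call pattern the kernels above use — GetFCInfo derives m/n/k, batch count, and strides from the operand shapes, and MatMulXPUFunction dispatches to xpu::fc_fusion (bs <= 1) or xpu::fc_batched, with the accumulate type picked by FCCalcType. A minimal sketch of a direct caller, assuming an XPU build, an initialized xpu::Context, and device-resident float buffers; the function name gemm_example is hypothetical:

    #include "paddle/phi/kernels/xpu/xpu_api_wrapper.h"

    // Computes out = x * y for row-major device buffers (illustration only).
    void gemm_example(xpu::Context* xpu_ctx,
                      const float* x,    // [m, k], on device
                      const float* y,    // [k, n], on device
                      float* out,        // [m, n], on device
                      const phi::DDim& x_dims,
                      const phi::DDim& y_dims) {
      // Fill m/n/k, batch size, and leading strides from the operand shapes.
      phi::XpuFcInfo fc_info;
      phi::GetFCInfo(x_dims, y_dims, /*trans_x=*/false, /*trans_y=*/false,
                     &fc_info);
      // Single or batched GEMM with alpha = 1.0 and no bias/activation.
      phi::MatMulXPUFunction<float>(xpu_ctx, x, y, out, fc_info, 1.0f);
    }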