/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#ifdef PADDLE_WITH_XPU
#include <cstdint>
#include <cstdlib>
#include <ostream>
#include <tuple>
#include <type_traits>
#include <vector>

#include "paddle/fluid/platform/device/device_wrapper.h"
#include "paddle/phi/kernels/funcs/blas/blas.h"

namespace paddle {
namespace operators {

using float16 = typename XPUTypeTrait<paddle::platform::float16>::Type;

enum XPUFCCalcType {
  FC_INT16 = 0,
  FC_INT32,
  FC_FLOAT,
};

// Selects the accumulation type used by the XPU FC kernels. float16 inputs
// always use int16 accumulation; for other types the default of int16 can be
// overridden through the XPU_PADDLE_FC_INT32 and XPU_PADDLE_FC_LOCAL_INT16
// environment variables.
template <typename T>
XPUFCCalcType FCCalcType() {
  if (std::is_same<paddle::platform::float16, T>::value ||
      std::is_same<float16, T>::value) {
    return XPUFCCalcType::FC_INT16;
  } else if (std::getenv("XPU_PADDLE_FC_INT32") != nullptr) {
    return XPUFCCalcType::FC_INT32;
  } else if (std::getenv("XPU_PADDLE_FC_LOCAL_INT16") != nullptr) {
    return XPUFCCalcType::FC_FLOAT;
  }
  return XPUFCCalcType::FC_INT16;
}

// Shape and stride metadata describing one (possibly batched) GEMM:
// out[bs, m, n] = x[bs, m, k] * y[bs, k, n], plus optional max-value buffers
// used by the quantized XPU kernels.
struct XpuFcInfo {
  int bs;
  int m;
  int n;
  int k;
  bool trans_x;
  bool trans_y;
  int stride_x;
  int stride_y;
  int stride_out;
  float* max_x;
  float* max_y;
  float* max_out;
  XpuFcInfo()
      : bs(0),
        m(0),
        n(0),
        k(0),
        trans_x(false),
        trans_y(false),
        stride_x(0),
        stride_y(0),
        stride_out(0),
        max_x(nullptr),
        max_y(nullptr),
        max_out(nullptr) {}
  void InitFcInfo(int bs, int m, int n, int k, bool trans_x, bool trans_y,
                  float* max_x, float* max_y, float* max_out) {
    this->bs = bs;
    this->m = m;
    this->n = n;
    this->k = k;
    this->trans_x = trans_x;
    this->trans_y = trans_y;
    this->max_x = max_x;
    this->max_y = max_y;
    this->max_out = max_out;

    if (this->bs <= 1) {
      this->stride_x = trans_x ? m : k;
      this->stride_y = trans_y ? k : n;
      this->stride_out = n;
    } else {
      this->stride_x = m * k;
      this->stride_y = k * n;
      this->stride_out = m * n;
    }
  }
};

static std::ostream& operator<<(std::ostream& os, const XpuFcInfo& fc_inf) {
  os << "fc_inf[ bs, m, n, k, trans_x, trans_y, stride_x, stride_y, "
        "stride_out] = "
     << "[" << fc_inf.bs << ", " << fc_inf.m << ", " << fc_inf.n << ", "
     << fc_inf.k << ", " << fc_inf.trans_x << ", " << fc_inf.trans_y << ", "
     << fc_inf.stride_x << ", " << fc_inf.stride_y << ", "
     << fc_inf.stride_out << "]";
  return os;
}
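// Illustrative sketch (not part of the original header): how XpuFcInfo is
// typically filled for a single non-batched GEMM, assuming x is [m, k] and
// y is [k, n] with no transposes; the max_* pointers may be null when no
// pre-computed maxima are available.
//
//   XpuFcInfo info;
//   // out[64, 32] = x[64, 128] * y[128, 32]
//   info.InitFcInfo(/*bs=*/0, /*m=*/64, /*n=*/32, /*k=*/128,
//                   /*trans_x=*/false, /*trans_y=*/false,
//                   /*max_x=*/nullptr, /*max_y=*/nullptr,
//                   /*max_out=*/nullptr);
//   VLOG(3) << info;  // prints the resolved m/n/k and strides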
// Normalizes the x/y shapes into matrix descriptors and fills `info` with
// the m/n/k, batch, and stride parameters expected by the XPU FC kernels.
static void GetFCInfo(const phi::DDim& x_dims, const phi::DDim& y_dims,
                      bool trans_x, bool trans_y, XpuFcInfo* info) {
  framework::DDim new_x_dims =
      (x_dims.size() > 1) ? x_dims : phi::make_ddim({1, x_dims[0]});
  framework::DDim new_y_dims =
      (y_dims.size() > 1) ? y_dims : phi::make_ddim({y_dims[0], 1});

  auto mat_dim_a = phi::funcs::CreateMatrixDescriptor(new_x_dims, 0, trans_x);
  auto mat_dim_b = phi::funcs::CreateMatrixDescriptor(new_y_dims, 0, trans_y);

  if (x_dims.size() >= 3 && y_dims.size() <= 2) {
    if (!trans_x) {
      mat_dim_a.height_ *= mat_dim_a.batch_size_;
      mat_dim_a.batch_size_ = 0;
    } else {
      mat_dim_b.batch_size_ = mat_dim_a.batch_size_;
      mat_dim_b.height_ = mat_dim_b.height_ / mat_dim_b.batch_size_;
    }
  }

  if (y_dims.size() >= 3 && x_dims.size() <= 2) {
    PADDLE_ENFORCE_EQ(
        mat_dim_b.trans_, false,
        platform::errors::InvalidArgument(
            "xpu does not support this shape in matmul_op: xdims = %s, "
            "ydims = %s, x_trans = %d, y_trans = %d",
            x_dims.to_str(), y_dims.to_str(), mat_dim_a.trans_,
            mat_dim_b.trans_));
    mat_dim_b.height_ *= mat_dim_b.batch_size_;
    mat_dim_b.batch_size_ = 0;
  }

  if (mat_dim_a.width_ == mat_dim_b.height_) {
    if (mat_dim_a.batch_size_ == 0 && mat_dim_b.batch_size_ == 1) {
      mat_dim_a.batch_size_ = mat_dim_b.batch_size_ = 0;
    }
    if (mat_dim_a.batch_size_ == 1 && mat_dim_b.batch_size_ == 0) {
      mat_dim_a.batch_size_ = mat_dim_b.batch_size_ = 0;
    }
  }

  PADDLE_ENFORCE_EQ(mat_dim_a.width_, mat_dim_b.height_,
                    platform::errors::InvalidArgument(
                        "Shape mistake in matmul_op: xdims = %s, ydims = %s, "
                        "x_trans = %d, y_trans = %d",
                        x_dims.to_str(), y_dims.to_str(), mat_dim_a.trans_,
                        mat_dim_b.trans_));

  info->m = mat_dim_a.height_;
  info->n = mat_dim_b.width_;
  info->k = mat_dim_a.width_;
  info->bs = mat_dim_a.batch_size_;
  info->trans_x = trans_x;
  info->trans_y = trans_y;

  if (info->bs <= 1) {
    info->stride_x = trans_x ? info->m : info->k;
    info->stride_y = trans_y ? info->k : info->n;
    info->stride_out = info->n;
  } else {
    info->stride_x = info->m * info->k;
    info->stride_y = info->k * info->n;
    info->stride_out = info->m * info->n;
  }
}

template <typename XPUType, typename FCT>
static void xpu_fc_wrapper(xpu::Context* ctx, const XPUType* x,
                           const XPUType* w, XPUType* y, int m, int n, int k,
                           bool x_trans, bool w_trans, const float* x_maxptr,
                           const float* w_maxptr, float* y_maxptr, int ldx,
                           int ldw, int ldy, float alpha, float beta,
                           const float* bias, const xpu::Activation_t& act) {
  int r = 0;
  if (x_trans && std::getenv("XPU_PADDLE_FC_TRANS_A") != nullptr &&
      std::is_same<float, XPUType>::value) {
    // Materialize the transpose of x first (in L3 or global memory) so the
    // FC itself can run without the transposed-A flag.
    XPUType* l3_addr = nullptr;
    xpu::ctx_guard RAII_GUARD(ctx);
    l3_addr = RAII_GUARD.alloc_l3_or_gm<XPUType>(m * k);
    PADDLE_ENFORCE_XDNN_NOT_NULL(l3_addr);

    std::vector<int> shape = {k, m};
    std::vector<int> axis = {1, 0};
    r = xpu::transpose<XPUType>(ctx, x, l3_addr, shape, axis);
    PADDLE_ENFORCE_XDNN_SUCCESS(r, "transpose");

    r = xpu::fc_fusion<XPUType, XPUType, XPUType, FCT>(
        ctx, l3_addr, w, y, m, n, k, false, w_trans, x_maxptr, w_maxptr,
        y_maxptr, k, ldw, ldy, alpha, beta, bias, act);
    PADDLE_ENFORCE_XDNN_SUCCESS(r, "fc_fusion");
  } else {
    r = xpu::fc_fusion<XPUType, XPUType, XPUType, FCT>(
        ctx, x, w, y, m, n, k, x_trans, w_trans, x_maxptr, w_maxptr, y_maxptr,
        ldx, ldw, ldy, alpha, beta, bias, act);
    PADDLE_ENFORCE_XDNN_SUCCESS(r, "fc_fusion");
  }
}

// float16 with int32 accumulation is not supported: fail fast.
template <>
void xpu_fc_wrapper<float16, int32_t>(
    xpu::Context* ctx, const float16* x, const float16* w, float16* y, int m,
    int n, int k, bool x_trans, bool w_trans, const float* x_maxptr,
    const float* w_maxptr, float* y_maxptr, int ldx, int ldw, int ldy,
    float alpha, float beta, const float* bias, const xpu::Activation_t& act) {
  int r = xpu::Error_t::INVALID_PARAM;
  PADDLE_ENFORCE_XDNN_SUCCESS(r, "xpu_fc_wrapper");
}

template <typename XPUType, typename FCT>
static void xpu_fc_batch_wrapper(xpu::Context* xpu_ctx, int bs, bool trans_x,
                                 bool trans_w, int m, int n, int k,
                                 float alpha, const XPUType* x, int stride_x,
                                 const XPUType* w, int stride_w, float beta,
                                 XPUType* y, int stride_y,
                                 const float* x_maxptr,
                                 const float* w_maxptr) {
  int r = xpu::fc_batched<XPUType, XPUType, XPUType, FCT>(
      xpu_ctx,                              // Context* ctx,
      bs,                                   // int batch_size,
      trans_x,                              // bool x_trans,
      trans_w,                              // bool w_trans,
      m,                                    // int m,
      n,                                    // int n,
      k,                                    // int k,
      alpha,                                // float alpha,
      reinterpret_cast<const XPUType*>(x),  // const TX* x,
      stride_x,                             // int stride_a,
      reinterpret_cast<const XPUType*>(w),  // const TW* w,
      stride_w,                             // int stride_b,
      0.0,                                  // float beta (the wrapper's beta
                                            // parameter is not forwarded)
      reinterpret_cast<XPUType*>(y),        // TY* y,
      stride_y,                             // int stride_c,
      x_maxptr,                             // const float* x_maxptr,
      w_maxptr);                            // const float* w_maxptr
  PADDLE_ENFORCE_XDNN_SUCCESS(r, "fc_batched");
}
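// Illustrative sketch (not part of the original header): a minimal call of
// the non-batched wrapper for an fp32 GEMM with int16 accumulation, assuming
// `ctx` is a live xpu::Context and x/w/y are device buffers of the right
// size. With nothing transposed, the leading dimensions follow the row-major
// convention used by XpuFcInfo: ldx = k, ldw = n, ldy = n.
//
//   xpu_fc_wrapper<float, int16_t>(
//       ctx, x, w, y, m, n, k, /*x_trans=*/false, /*w_trans=*/false,
//       /*x_maxptr=*/nullptr, /*w_maxptr=*/nullptr, /*y_maxptr=*/nullptr,
//       /*ldx=*/k, /*ldw=*/n, /*ldy=*/n, /*alpha=*/1.0f, /*beta=*/0.0f,
//       /*bias=*/nullptr, xpu::Activation_t::LINEAR);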
// float16 is only supported with int16 accumulation; the int32 and float
// instantiations fail fast.
template <>
void xpu_fc_batch_wrapper<float16, int32_t>(
    xpu::Context* xpu_ctx, int bs, bool trans_x, bool trans_w, int m, int n,
    int k, float alpha, const float16* x, int stride_x, const float16* w,
    int stride_w, float beta, float16* y, int stride_y, const float* x_maxptr,
    const float* w_maxptr) {
  int r = xpu::Error_t::INVALID_PARAM;
  PADDLE_ENFORCE_XDNN_SUCCESS(r, "xpu_fc_batch_wrapper");
}

template <>
void xpu_fc_batch_wrapper<float16, float>(
    xpu::Context* xpu_ctx, int bs, bool trans_x, bool trans_w, int m, int n,
    int k, float alpha, const float16* x, int stride_x, const float16* w,
    int stride_w, float beta, float16* y, int stride_y, const float* x_maxptr,
    const float* w_maxptr) {
  int r = xpu::Error_t::INVALID_PARAM;
  PADDLE_ENFORCE_XDNN_SUCCESS(r, "xpu_fc_batch_wrapper");
}

// Dispatches a (batched) matmul described by `fcinfo` to the FC kernel whose
// accumulation type matches FCCalcType<XPUType>().
template <typename T>
static void MatMulXPUFunction(xpu::Context* xpu_ctx, const T* x, const T* y,
                              T* out, const XpuFcInfo& fcinfo, float alpha) {
  using XPUType = typename XPUTypeTrait<T>::Type;
  using float16 = typename XPUTypeTrait<paddle::platform::float16>::Type;
  int fccal_type = FCCalcType<XPUType>();

  // Indexed by XPUFCCalcType: FC_INT16, FC_INT32, FC_FLOAT.
  decltype(&paddle::operators::xpu_fc_wrapper<XPUType, int16_t>)
      fc_api_list[3] = {
          &paddle::operators::xpu_fc_wrapper<XPUType, int16_t>,
          &paddle::operators::xpu_fc_wrapper<XPUType, int32_t>,
          &paddle::operators::xpu_fc_wrapper<XPUType, float>,
      };
  decltype(&paddle::operators::xpu_fc_batch_wrapper<XPUType, int16_t>)
      fc_batch_api_list[3] = {
          &paddle::operators::xpu_fc_batch_wrapper<XPUType, int16_t>,
          &paddle::operators::xpu_fc_batch_wrapper<XPUType, int32_t>,
          &paddle::operators::xpu_fc_batch_wrapper<XPUType, float>,
      };

  auto fc_api = fc_api_list[fccal_type];
  auto fc_batch_api = fc_batch_api_list[fccal_type];

  int m = fcinfo.m;
  int n = fcinfo.n;
  int k = fcinfo.k;
  int batch_size = fcinfo.bs;
  int ldx = fcinfo.stride_x;
  int ldy = fcinfo.stride_y;
  int ldout = fcinfo.stride_out;
  bool trans_x = fcinfo.trans_x;
  bool trans_y = fcinfo.trans_y;
  float* max_x = fcinfo.max_x;
  float* max_y = fcinfo.max_y;
  float* max_out = fcinfo.max_out;

  if (batch_size <= 1) {
    fc_api(xpu_ctx, reinterpret_cast<const XPUType*>(x),
           reinterpret_cast<const XPUType*>(y),
           reinterpret_cast<XPUType*>(out), m, n, k, trans_x, trans_y, max_x,
           max_y, max_out, ldx, ldy, ldout, alpha, 0, nullptr,
           xpu::Activation_t::LINEAR);
  } else {
    // batch matmul
    fc_batch_api(xpu_ctx,     // Context* ctx,
                 batch_size,  // int batch_size,
                 trans_x,     // bool x_trans,
                 trans_y,     // bool w_trans,
                 m,           // int m,
                 n,           // int n,
                 k,           // int k,
                 alpha,       // float alpha,
                 reinterpret_cast<const XPUType*>(x),  // const TX* x,
                 ldx,         // int stride_a,
                 reinterpret_cast<const XPUType*>(y),  // const TW* w,
                 ldy,         // int stride_b,
                 0.0,         // float beta,
                 reinterpret_cast<XPUType*>(out),      // TY* y,
                 ldout,       // int stride_c,
                 max_x,       // const float* x_maxptr,
                 max_y);      // const float* w_maxptr
  }
}
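// Illustrative sketch (not part of the original header): the typical
// forward-path pairing of GetFCInfo and MatMulXPUFunction, assuming
// `dev_ctx` is an XPU device context exposing x_context() and
// x_data/y_data/out_data are device pointers for tensors of shape
// x_dims/y_dims.
//
//   XpuFcInfo fc_info;
//   GetFCInfo(x_dims, y_dims, /*trans_x=*/false, /*trans_y=*/false,
//             &fc_info);
//   MatMulXPUFunction<float>(dev_ctx.x_context(), x_data, y_data, out_data,
//                            fc_info, /*alpha=*/1.0f);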
// Builds the shapes and operand pointers for the two matmul_grad GEMMs,
// dx = f(dout, y) and dy = g(x, dout), staging dout (and its max) in L3
// memory when it fits.
template <typename T>
static std::tuple<XpuFcInfo, XpuFcInfo, const T*, const T*, const T*,
                  const T*>
MatmulGradFcInfo(xpu::Context* xpu_ctx, xpu::ctx_guard* RAII_GUARD,
                 const XpuFcInfo& dout_shape, bool trans_x, bool trans_y,
                 const T* x, const T* y, const T* dout) {
  XpuFcInfo dx_shape, dy_shape;
  const T* dx_a = NULL;
  const T* dx_b = NULL;
  const T* dy_a = NULL;
  const T* dy_b = NULL;
  bool copy_to_l3 = false;
  float* max_dout = NULL;
  int maxptr_size = xpu_ctx->max_ptr_size();
  uint64_t l3_size = uint64_t(xpu_ctx->_l3_mgr.get_size());
  int bs = (dout_shape.bs <= 1) ? (1) : (dout_shape.bs);
  int dx_size = bs * dout_shape.m * dout_shape.k;
  int dy_size = bs * dout_shape.k * dout_shape.n;
  int dout_size = bs * dout_shape.m * dout_shape.n;
  if (trans_x && trans_y) {
    copy_to_l3 = l3_size >= (dout_size * 2 + dy_size) * sizeof(T);
  } else if (trans_x) {
    copy_to_l3 = l3_size >= dout_size * sizeof(T);
  } else if (trans_y) {
    copy_to_l3 = l3_size >= dout_size * 2 * sizeof(T);
  } else {
    copy_to_l3 = l3_size >= (dout_size + dx_size) * sizeof(T);
  }
  const T* dout_new = dout;
  int r = 0;
  if (copy_to_l3) {
    T* dout_l3 = RAII_GUARD->alloc_l3<T>(dout_size);
    PADDLE_ENFORCE_XDNN_NOT_NULL(dout_l3);
    if ((dout_shape.bs > 1) ||
        ((dout_shape.bs <= 1) &&
         (FCCalcType<T>() == XPUFCCalcType::FC_FLOAT))) {
      // Batched and float paths do not use max buffers: a plain copy
      // suffices.
      r = xpu::copy(xpu_ctx, dout, dout_l3, dout_size);
      PADDLE_ENFORCE_XDNN_SUCCESS(r, "copy");
      dout_new = dout_l3;
    } else {
      max_dout = RAII_GUARD->alloc_l3_or_gm<float>(maxptr_size);
      PADDLE_ENFORCE_XDNN_NOT_NULL(max_dout);
      r = xpu::findmax_copy_fusion(xpu_ctx, dout, max_dout, dout_l3,
                                   dout_size);
      PADDLE_ENFORCE_XDNN_SUCCESS(r, "findmax_copy_fusion");
      dout_new = dout_l3;
    }
  } else if (((dout_shape.bs <= 1) &&
              (FCCalcType<T>() != XPUFCCalcType::FC_FLOAT))) {
    // No room in L3: still compute max(dout) for the quantized kernels.
    max_dout = RAII_GUARD->alloc_l3_or_gm<float>(maxptr_size);
    PADDLE_ENFORCE_XDNN_NOT_NULL(max_dout);
    r = xpu::findmax_copy_fusion(xpu_ctx, dout, max_dout,
                                 reinterpret_cast<T*>(NULL), dout_size);
    PADDLE_ENFORCE_XDNN_SUCCESS(r, "findmax_copy_fusion");
  }

  if (trans_x && trans_y) {
    // dx = T(y) * T(dout)
    dx_shape.InitFcInfo(dout_shape.bs, dout_shape.k, dout_shape.m,
                        dout_shape.n, true, true, nullptr, max_dout, nullptr);
    dx_a = y, dx_b = dout_new;
    // dy = T(dout) * T(x)
    dy_shape.InitFcInfo(dout_shape.bs, dout_shape.n, dout_shape.k,
                        dout_shape.m, true, true, max_dout, nullptr, nullptr);
    dy_a = dout_new, dy_b = x;
  } else if (trans_x) {
    // dx = y * T(dout)
    dx_shape.InitFcInfo(dout_shape.bs, dout_shape.k, dout_shape.m,
                        dout_shape.n, false, true, nullptr, max_dout,
                        nullptr);
    dx_a = y, dx_b = dout_new;
    // dy = x * dout
    dy_shape.InitFcInfo(dout_shape.bs, dout_shape.k, dout_shape.n,
                        dout_shape.m, false, false, nullptr, max_dout,
                        nullptr);
    dy_a = x, dy_b = dout_new;
  } else if (trans_y) {
    // dx = dout * y
    dx_shape.InitFcInfo(dout_shape.bs, dout_shape.m, dout_shape.k,
                        dout_shape.n, false, false, max_dout, nullptr,
                        nullptr);
    dx_a = dout_new, dx_b = y;
    // dy = T(dout) * x
    dy_shape.InitFcInfo(dout_shape.bs, dout_shape.n, dout_shape.k,
                        dout_shape.m, true, false, max_dout, nullptr,
                        nullptr);
    dy_a = dout_new, dy_b = x;
  } else {
    // dx = dout * T(y)
    dx_shape.InitFcInfo(dout_shape.bs, dout_shape.m, dout_shape.k,
                        dout_shape.n, false, true, max_dout, nullptr,
                        nullptr);
    dx_a = dout_new, dx_b = y;
    // dy = T(x) * dout
    dy_shape.InitFcInfo(dout_shape.bs, dout_shape.k, dout_shape.n,
                        dout_shape.m, true, false, nullptr, max_dout,
                        nullptr);
    dy_a = x, dy_b = dout_new;
  }
  std::tuple<XpuFcInfo, XpuFcInfo, const T*, const T*, const T*, const T*>
      result = std::make_tuple(dx_shape, dy_shape, dx_a, dx_b, dy_a, dy_b);

  return result;
}

}  // namespace operators
}  // namespace paddle
#endif