From e512aa9a4be2b4bcf55e5b4b2102ff133f898a95 Mon Sep 17 00:00:00 2001 From: QingshuChen Date: Tue, 2 Nov 2021 18:42:52 +0800 Subject: [PATCH] support different precision in kunlun (#36836) * support different precision in kunlun * minor * minor * minor --- cmake/external/xpu.cmake | 3 +- paddle/fluid/operators/matmul_op_xpu.cc | 21 +++++---- paddle/fluid/operators/matmul_v2_op_xpu.cc | 20 ++++++-- paddle/fluid/operators/xpu_api_wrapper.h | 53 ++++++++++++++++++++++ paddle/fluid/platform/device_context.cc | 6 ++- paddle/fluid/platform/xpu/xpu2_op_list.h | 39 ++++++++++++++++ 6 files changed, 126 insertions(+), 16 deletions(-) create mode 100644 paddle/fluid/operators/xpu_api_wrapper.h diff --git a/cmake/external/xpu.cmake b/cmake/external/xpu.cmake index 11a7adbbeb9..d12f51c82b2 100644 --- a/cmake/external/xpu.cmake +++ b/cmake/external/xpu.cmake @@ -35,7 +35,8 @@ ELSE () ENDIF() SET(XPU_BASE_URL_WITHOUT_DATE "https://baidu-kunlun-product.cdn.bcebos.com/KL-SDK/klsdk-dev") -SET(XPU_BASE_URL "${XPU_BASE_URL_WITHOUT_DATE}/20211020") +SET(XPU_BASE_URL "${XPU_BASE_URL_WITHOUT_DATE}/20211029") +#SET(XPU_BASE_URL "${XPU_BASE_URL_WITHOUT_DATE}/20211020") SET(XPU_XRE_URL "${XPU_BASE_URL}/${XPU_XRE_DIR_NAME}.tar.gz" CACHE STRING "" FORCE) SET(XPU_XDNN_URL "${XPU_BASE_URL}/${XPU_XDNN_DIR_NAME}.tar.gz" CACHE STRING "" FORCE) SET(XPU_XCCL_URL "${XPU_BASE_URL_WITHOUT_DATE}/20210623/${XPU_XCCL_DIR_NAME}.tar.gz" CACHE STRING "" FORCE) diff --git a/paddle/fluid/operators/matmul_op_xpu.cc b/paddle/fluid/operators/matmul_op_xpu.cc index 7097b5327d8..53593d2db01 100644 --- a/paddle/fluid/operators/matmul_op_xpu.cc +++ b/paddle/fluid/operators/matmul_op_xpu.cc @@ -20,6 +20,7 @@ limitations under the License. */ #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/operators/math/blas.h" +#include "paddle/fluid/operators/xpu_api_wrapper.h" namespace paddle { namespace operators { @@ -151,28 +152,26 @@ static void MatMulXPUFunction(const Tensor *x, const Tensor *y, Tensor *out, x_dims.to_str().c_str(), y_dims.to_str().c_str())); float alpha = static_cast(ctx.Attr("alpha")); - T *data_c = out->data(); int m = mat_dim_a.height_; int n = mat_dim_b.width_; int k = mat_dim_a.width_; int batch_size = mat_dim_a.batch_size_; - int ldx = mat_dim_a.trans_ ? m : k; int ldy = mat_dim_b.trans_ ? k : n; int ldout = n; if (batch_size <= 1) { int r = 0; - r = xpu::fc_fusion( + r = xpu_fc_wrapper( dev_ctx.x_context(), reinterpret_cast(x->data()), reinterpret_cast(y->data()), reinterpret_cast(data_c), m, n, k, mat_dim_a.trans_, mat_dim_b.trans_, nullptr, nullptr, nullptr, ldx, ldy, ldout, alpha, 0, nullptr, xpu::Activation_t::LINEAR); - PADDLE_ENFORCE_EQ(r, XPU_SUCCESS, - platform::errors::External( - "XPU fc_fusion kernel return wrong value[%d %s]", r, - XPUAPIErrorMsg[r])); + PADDLE_ENFORCE_EQ( + r, XPU_SUCCESS, + platform::errors::External("XPU fc kernel return wrong value[%d %s]", r, + XPUAPIErrorMsg[r])); } else { // batch matmul int r = xpu::fc_batched( @@ -216,8 +215,10 @@ class MatMulXPUKernel : public framework::OpKernel { if (std::is_same::value) { MatMulXPUFunction(x, y, out, trans_x, trans_y, context); } else { - if (std::getenv("XPU_PADDLE_MAT_MUL_FCINT32") != nullptr) { + if (std::getenv("XPU_PADDLE_FC_INT32") != nullptr) { MatMulXPUFunction(x, y, out, trans_x, trans_y, context); + } else if (std::getenv("XPU_PADDLE_FC_LOCAL_INT16") != nullptr) { + MatMulXPUFunction(x, y, out, trans_x, trans_y, context); } else { MatMulXPUFunction(x, y, out, trans_x, trans_y, context); } @@ -292,8 +293,10 @@ class MatMulGradXPUKernel : public framework::OpKernel { if (std::is_same::value) { MatMulXPUFunction(&a, &b, out, trans_a, trans_b, context); } else { - if (std::getenv("XPU_PADDLE_MAT_MUL_GRAD_FCINT32") != nullptr) { + if (std::getenv("XPU_PADDLE_FC_INT32") != nullptr) { MatMulXPUFunction(&a, &b, out, trans_a, trans_b, context); + } else if (std::getenv("XPU_PADDLE_FC_LOCAL_INT16") != nullptr) { + MatMulXPUFunction(&a, &b, out, trans_a, trans_b, context); } else { MatMulXPUFunction(&a, &b, out, trans_a, trans_b, context); } diff --git a/paddle/fluid/operators/matmul_v2_op_xpu.cc b/paddle/fluid/operators/matmul_v2_op_xpu.cc index ae1e9358f68..908a23c4ecc 100644 --- a/paddle/fluid/operators/matmul_v2_op_xpu.cc +++ b/paddle/fluid/operators/matmul_v2_op_xpu.cc @@ -18,6 +18,8 @@ #include #include +#include "paddle/fluid/operators/xpu_api_wrapper.h" + namespace paddle { namespace operators { @@ -74,17 +76,21 @@ static void MatMulXPUFunction(const Tensor* x, const Tensor* y, Tensor* out, int n = mat_dim_b.width_; int k = mat_dim_a.width_; int batch_size = mat_dim_a.batch_size_; + int ldx = mat_dim_a.trans_ ? m : k; + int ldy = mat_dim_b.trans_ ? k : n; + int ldout = n; if (batch_size <= 1) { int r = 0; - r = xpu::fc( + r = xpu_fc_wrapper( dev_ctx.x_context(), reinterpret_cast(x->data()), reinterpret_cast(y->data()), reinterpret_cast(data_c), m, n, k, mat_dim_a.trans_, - mat_dim_b.trans_, nullptr, nullptr, nullptr); + mat_dim_b.trans_, nullptr, nullptr, nullptr, ldx, ldy, ldout, 1.0, 0, + nullptr, xpu::Activation_t::LINEAR); PADDLE_ENFORCE_EQ( r, XPU_SUCCESS, platform::errors::External( - "XPU fc_fusion kernel return wrong value[%d %s] , m = %d, n = " + "XPU fc kernel return wrong value[%d %s] , m = %d, n = " "%d, " "k = %d, a_tr = %d, b_tr = %d", r, XPUAPIErrorMsg[r], m, n, k, mat_dim_a.trans_, mat_dim_b.trans_)); @@ -129,8 +135,10 @@ class MatMulV2XPUKernel : public framework::OpKernel { if (std::is_same::value) { MatMulXPUFunction(x, y, out, trans_x, trans_y, ctx); } else { - if (std::getenv("XPU_PADDLE_MAT_MUL_V2_FCINT32") != nullptr) { + if (std::getenv("XPU_PADDLE_FC_INT32") != nullptr) { MatMulXPUFunction(x, y, out, trans_x, trans_y, ctx); + } else if (std::getenv("XPU_PADDLE_FC_LOCAL_INT16") != nullptr) { + MatMulXPUFunction(x, y, out, trans_x, trans_y, ctx); } else { MatMulXPUFunction(x, y, out, trans_x, trans_y, ctx); } @@ -178,8 +186,10 @@ class MatMulV2XPUGradKernel : public framework::OpKernel { if (std::is_same::value) { MatMulXPUFunction(&a, &b, out, trans_a, trans_b, ctx); } else { - if (std::getenv("XPU_PADDLE_MAT_MUL_GRAD_V2_FCINT32") != nullptr) { + if (std::getenv("XPU_PADDLE_FC_INT32") != nullptr) { MatMulXPUFunction(&a, &b, out, trans_a, trans_b, ctx); + } else if (std::getenv("XPU_PADDLE_FC_LOCAL_INT16") != nullptr) { + MatMulXPUFunction(&a, &b, out, trans_a, trans_b, ctx); } else { MatMulXPUFunction(&a, &b, out, trans_a, trans_b, ctx); } diff --git a/paddle/fluid/operators/xpu_api_wrapper.h b/paddle/fluid/operators/xpu_api_wrapper.h new file mode 100644 index 00000000000..4fdb33ca6c4 --- /dev/null +++ b/paddle/fluid/operators/xpu_api_wrapper.h @@ -0,0 +1,53 @@ +/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ +#pragma once + +#ifdef PADDLE_WITH_XPU +#include + +namespace paddle { +namespace operators { + +template +int xpu_fc_wrapper(xpu::Context* ctx, const XPUType* x, const XPUType* w, + XPUType* y, int m, int n, int k, bool x_trans, bool w_trans, + const float* x_maxptr, const float* w_maxptr, + float* y_maxptr, int ldx, int ldw, int ldy, float alpha, + float beta, const float* bias, + const xpu::Activation_t& act) { + int r = 0; + if (x_trans && std::getenv("XPU_PADDLE_FC_TRANS_A") != nullptr && + std::is_same::value) { + XPUType* l3_addr = nullptr; + xpu::ctx_guard RAII_GUARD(ctx); + l3_addr = RAII_GUARD.alloc_l3_or_gm(m * k); + if (l3_addr == nullptr) return XPUERR_NOMEM; + + std::vector shape = {k, m}; + std::vector axis = {1, 0}; + r = xpu::transpose(ctx, x, l3_addr, shape, axis); + if (r != XPU_SUCCESS) return r; + + r = xpu::fc_fusion( + ctx, l3_addr, w, y, m, n, k, false, w_trans, x_maxptr, w_maxptr, + y_maxptr, k, ldw, ldy, alpha, beta, bias, act); + if (r != XPU_SUCCESS) return r; + } else { + r = xpu::fc_fusion( + ctx, x, w, y, m, n, k, x_trans, w_trans, x_maxptr, w_maxptr, y_maxptr, + ldx, ldw, ldy, alpha, beta, bias, act); + } + return r; +} + +} // namespace operators +} // namespace paddle +#endif diff --git a/paddle/fluid/platform/device_context.cc b/paddle/fluid/platform/device_context.cc index d934918d018..cc3aab3ecdb 100644 --- a/paddle/fluid/platform/device_context.cc +++ b/paddle/fluid/platform/device_context.cc @@ -222,9 +222,13 @@ XPUDeviceContext::XPUDeviceContext(XPUPlace place) : place_(place) { context_ = xpu::create_context(); const int MAX_XPU_NUM = 16; - const int l3_size = 13.5 * 1024 * 1024; static void* l3ptrs[MAX_XPU_NUM] = {nullptr}; + int l3_size = 13.5 * 1024 * 1024; + if (std::getenv("XPU_PADDLE_L3_SIZE") != nullptr) { + l3_size = atoi(std::getenv("XPU_PADDLE_L3_SIZE")); + } + auto selected_xpus = GetXPUSelectedDevices(); for (unsigned int i = 0; i < selected_xpus.size(); i++) { if (place.device == selected_xpus[i]) { diff --git a/paddle/fluid/platform/xpu/xpu2_op_list.h b/paddle/fluid/platform/xpu/xpu2_op_list.h index 389166c0005..d1a3bb5dd3c 100644 --- a/paddle/fluid/platform/xpu/xpu2_op_list.h +++ b/paddle/fluid/platform/xpu/xpu2_op_list.h @@ -90,6 +90,12 @@ XPUOpMap& get_kl2_ops() { pOpKernelType(vartype::FP16, XPUPlace())})}, {"adam", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, {"adamw", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, + {"reduce_sum", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, + {"reduce_sum_grad", + XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, + {"softmax", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, + {"softmax_grad", + XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, {"softmax_with_cross_entropy", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, {"softmax_with_cross_entropy_grad", @@ -171,6 +177,39 @@ XPUOpMap& get_kl2_ops() { pOpKernelType(vartype::INT32, XPUPlace()), pOpKernelType(vartype::INT8, XPUPlace()), pOpKernelType(vartype::FP32, XPUPlace())})}, + {"matmul_v2", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, + {"matmul_v2_grad", + XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, + {"matmul", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, + {"matmul_grad", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, + {"relu", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, + {"relu_grad", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, + {"assign_value", + XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, + {"dropout", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, + {"dropout_grad", + XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, + {"elementwise_div", + XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, + {"elementwise_div_grad", + XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, + {"range", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()), + pOpKernelType(vartype::INT64, XPUPlace())})}, + {"reshape2", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, + {"reshape2_grad", + XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, + {"shape", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()), + pOpKernelType(vartype::INT64, XPUPlace())})}, + {"one_hot_v2", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()), + pOpKernelType(vartype::INT64, XPUPlace())})}, + {"layer_norm", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, + {"layer_norm_grad", + XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, + {"lookup_table_v2", + XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, + {"lookup_table_v2_grad", + XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, + {"scale", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, {"flatten_contiguous_range", XPUKernelSet({pOpKernelType(vartype::INT64, XPUPlace()), -- GitLab