From 1a1a2ce8072250b96ca216161a21db9b40a6c136 Mon Sep 17 00:00:00 2001 From: Liu-xiandong <85323580+Liu-xiandong@users.noreply.github.com> Date: Wed, 23 Feb 2022 12:25:25 +0800 Subject: [PATCH] [KP] Add elementwise add xpu after phi, test=develop (#39787) * [KP] Add elementwise add xpu, test=develop * modify the File Permissions * modify the copyright time * modify code style * modify code style --- .pre-commit-config.yaml | 4 +- cmake/operators.cmake | 15 +- cmake/xpu_kp.cmake | 14 +- .../elementwise/elementwise_add_op.kps | 188 ++++++++++ paddle/fluid/platform/device_context.h | 3 + paddle/phi/core/hostdevice.h | 4 +- paddle/phi/kernels/funcs/broadcast_function.h | 14 +- paddle/phi/kernels/funcs/eigen/extensions.h | 4 + paddle/phi/kernels/funcs/elementwise_base.h | 20 +- paddle/phi/kernels/gpu/elementwise.h | 127 +++---- .../primitive/compute_primitives_xpu2.h | 4 +- .../primitive/datamover_primitives_xpu2.h | 46 +-- .../primitive/functor_primitives_xpu2.h | 209 +++++++++++ .../phi/kernels/primitive/helper_primitives.h | 2 +- .../phi/kernels/primitive/kernel_primitives.h | 11 +- .../xpu/test_elementwise_add_op_xpu_kp.py | 341 ++++++++++++++++++ 16 files changed, 890 insertions(+), 116 deletions(-) create mode 100644 paddle/fluid/operators/elementwise/elementwise_add_op.kps create mode 100755 paddle/phi/kernels/primitive/functor_primitives_xpu2.h create mode 100644 python/paddle/fluid/tests/unittests/xpu/test_elementwise_add_op_xpu_kp.py diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index df2e59b7647..2684529930e 100755 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -25,7 +25,7 @@ repos: description: Format files with ClangFormat. 
entry: bash ./tools/codestyle/clang_format.hook -i language: system - files: \.(c|cc|cxx|cpp|cu|h|hpp|hxx|proto)$ + files: \.(c|cc|cxx|cpp|cu|h|hpp|hxx|proto|xpu|kps)$ - repo: local hooks: - id: cpplint-cpp-source @@ -48,7 +48,7 @@ repos: name: copyright_checker entry: python ./tools/codestyle/copyright.hook language: system - files: \.(c|cc|cxx|cpp|cu|h|hpp|hxx|proto|py|sh)$ + files: \.(c|cc|cxx|cpp|cu|h|hpp|hxx|proto|xpu|kps|py|sh)$ exclude: | (?x)^( paddle/utils/.* diff --git a/cmake/operators.cmake b/cmake/operators.cmake index 8469dc4c02e..8843dd26287 100644 --- a/cmake/operators.cmake +++ b/cmake/operators.cmake @@ -125,6 +125,9 @@ function(op_library TARGET) if (EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/${TARGET}.xpu) list(APPEND xpu_kp_cc_srcs ${TARGET}.xpu) endif() + if (EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/${TARGET}.kps) + list(APPEND xpu_kp_cc_srcs ${TARGET}.kps) + endif() endif() if(WITH_ASCEND_CL) string(REPLACE "_op" "_op_npu" NPU_FILE "${TARGET}") @@ -162,6 +165,8 @@ function(op_library TARGET) list(APPEND xpu_cc_srcs ${src}) elseif(WITH_XPU_KP AND ${src} MATCHES ".*\\.xpu$") list(APPEND xpu_kp_cc_srcs ${src}) + elseif(WITH_XPU_KP AND ${src} MATCHES ".*\\.kps$") + list(APPEND xpu_kp_cc_srcs ${src}) elseif(WITH_ASCEND_CL AND ${src} MATCHES ".*_op_npu.cc$") list(APPEND npu_cc_srcs ${src}) elseif(WITH_MLU AND ${src} MATCHES ".*_op_mlu.cc$") @@ -384,7 +389,15 @@ function(op_library TARGET) # pybind USE_OP_DEVICE_KERNEL for XPU KP if (WITH_XPU_KP AND ${xpu_kp_cc_srcs_len} GREATER 0) - file(APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL(${TARGET}, KP);\n") + foreach(xpu_kp_src ${xpu_kp_cc_srcs}) + set(op_name "") + find_register(${xpu_kp_src} "REGISTER_OP_KERNEL" op_name) + if(NOT ${op_name} EQUAL "") + file(APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL(${op_name}, KP);\n") + message(STATUS "Building KP Target: ${op_name}") + set(pybind_flag 1) + endif() + endforeach() endif() # pybind USE_OP_DEVICE_KERNEL for NPU diff --git a/cmake/xpu_kp.cmake b/cmake/xpu_kp.cmake 
index f8ab9693db0..adab3e1423c 100644 --- a/cmake/xpu_kp.cmake +++ b/cmake/xpu_kp.cmake @@ -17,7 +17,7 @@ if(NOT WITH_XPU_KP) endif() if(NOT XPU_TOOLCHAIN) - set(XPU_TOOLCHAIN /workspace/paddle/xpu-demo/XTDK) + set(XPU_TOOLCHAIN /workspace/output/XTDK-ubuntu_x86_64) get_filename_component(XPU_TOOLCHAIN ${XPU_TOOLCHAIN} REALPATH) endif() if(NOT IS_DIRECTORY ${XPU_TOOLCHAIN}) @@ -102,7 +102,7 @@ macro(compile_kernel COMPILE_ARGS) set(XTDK_DIR ${XPU_TOOLCHAIN}) set(CXX_DIR ${HOST_SYSROOT}) - set(XPU_CXX_FLAGS -Wno-error=pessimizing-move -Wno-error=constant-conversion -Wno-error=c++11-narrowing -Wno-error=shift-count-overflow -Wno-error=unused-local-typedef -Wno-error=deprecated-declarations -Wno-deprecated-declarations -std=c++14 -m64 -fPIC -fno-omit-frame-pointer -Wall -Wno-inconsistent-missing-override -Wextra -Wnon-virtual-dtor -Wdelete-non-virtual-dtor -Wno-unused-parameter -Wno-unused-function -Wno-error=unused-local-typedefs -Wno-error=ignored-attributes -Wno-error=int-in-bool-context -Wno-error=parentheses -Wno-error=address -Wno-ignored-qualifiers -Wno-ignored-attributes -Wno-parentheses -DNDEBUG ) + set(XPU_CXX_FLAGS -fforce-enable-int128 -Wno-error=pessimizing-move -Wno-error=constant-conversion -Wno-error=c++11-narrowing -Wno-error=shift-count-overflow -Wno-error=unused-local-typedef -Wno-error=deprecated-declarations -Wno-deprecated-declarations -std=c++14 -m64 -fPIC -fno-omit-frame-pointer -Wall -Wno-inconsistent-missing-override -Wextra -Wnon-virtual-dtor -Wdelete-non-virtual-dtor -Wno-unused-parameter -Wno-unused-function -Wno-error=unused-local-typedefs -Wno-error=ignored-attributes -Wno-error=int-in-bool-context -Wno-error=parentheses -Wno-error=address -Wno-ignored-qualifiers -Wno-ignored-attributes -Wno-parentheses -DNDEBUG ) #include path get_property(dirs DIRECTORY ${CMAKE_SOURCE_DIR} PROPERTY INCLUDE_DIRECTORIES) @@ -127,9 +127,11 @@ macro(compile_kernel COMPILE_ARGS) kernel_build/${kernel_name}.bin.o COMMAND ${CMAKE_COMMAND} -E make_directory 
kernel_build + COMMAND + cp ${kernel_path}/${kernel_name}.kps kernel_build/${kernel_name}.xpu COMMAND ${XPU_CLANG} --sysroot=${CXX_DIR} -std=c++11 -D_GLIBCXX_USE_CXX11_ABI=1 ${OPT_LEVEL} -fno-builtin -mcpu=xpu2 -fPIC ${XPU_CXX_DEFINES} ${XPU_CXX_FLAGS} ${XPU_CXX_INCLUDES} - -I. -o kernel_build/${kernel_name}.bin.o.sec ${kernel_path}/${kernel_name}.xpu + -I. -o kernel_build/${kernel_name}.bin.o.sec kernel_build/${kernel_name}.xpu --xpu-device-only -c -v COMMAND ${XTDK_DIR}/bin/xpu2-elfconv kernel_build/${kernel_name}.bin.o.sec kernel_build/${kernel_name}.bin.o ${XPU_CLANG} --sysroot=${CXX_DIR} @@ -148,9 +150,11 @@ macro(compile_kernel COMPILE_ARGS) kernel_build/${kernel_name}.host.o COMMAND ${CMAKE_COMMAND} -E make_directory kernel_build + COMMAND + cp ${kernel_path}/${kernel_name}.kps kernel_build/${kernel_name}.xpu COMMAND ${XPU_CLANG} --sysroot=${CXX_DIR} -std=c++11 -D_GLIBCXX_USE_CXX11_ABI=1 ${OPT_LEVEL} -fno-builtin -mcpu=xpu2 -fPIC ${XPU_CXX_DEFINES} ${XPU_CXX_FLAGS} ${XPU_CXX_INCLUDES} - -I. -o kernel_build/${kernel_name}.host.o ${kernel_path}/${kernel_name}.xpu + -I. -o kernel_build/${kernel_name}.host.o kernel_build/${kernel_name}.xpu --xpu-host-only -c -v WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR} @@ -185,7 +189,7 @@ macro(xpu_add_library TARGET_NAME) # Distinguish .xpu file from other files foreach(cur_xpu_src IN LISTS xpu_srcs_lists) get_filename_component(language_type_name ${cur_xpu_src} EXT) - if(${language_type_name} STREQUAL ".xpu") + if(${language_type_name} STREQUAL ".kps") list(APPEND xpu_kernel_lists ${cur_xpu_src}) else() list(APPEND cc_kernel_lists ${cur_xpu_src}) diff --git a/paddle/fluid/operators/elementwise/elementwise_add_op.kps b/paddle/fluid/operators/elementwise/elementwise_add_op.kps new file mode 100644 index 00000000000..a3fea0d7b3d --- /dev/null +++ b/paddle/fluid/operators/elementwise/elementwise_add_op.kps @@ -0,0 +1,188 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. 
+Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +// Please do not modify the following code +#if defined(__CUDA_ARCH__) +#undef __CUDA_ARCH__ +#endif + +#if defined(__CUDACC__) +#undef __CUDACC__ +#endif + +#if defined(__CUDA__) +#undef __CUDA__ +#endif + +#if defined(__NVCC__) +#undef __NVCC__ +#endif + +#ifdef PADDLE_WITH_XPU_KP +#include // NOLINT +#include "xpu/kernel/cluster_header.h" // NOLINT +#include "xpu/kernel/debug.h" // NOLINT +#include "xpu/kernel/math.h" // NOLINT + +#include +#include +#include "paddle/fluid/operators/elementwise/elementwise_add_op.h" +#include "paddle/fluid/operators/elementwise/elementwise_op.h" +#include "paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h" +#include "paddle/fluid/operators/elementwise/elementwise_xpu.h" +#include "paddle/fluid/platform/device/device_wrapper.h" + +namespace paddle { +namespace operators { + +template +class ElementwiseAddXPUKPKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + std::vector ins; + std::vector outs; + int axis = PackTensorsIntoVector(ctx, &ins, &outs); + const auto& xpu_ctx = + ctx.template device_context(); + paddle::operators::LaunchElementwiseCudaKernel, 1>( + xpu_ctx, ins, &outs, axis, kps::AddFunctor()); + } +}; + +static std::vector get_rdims(const std::vector& xdims, + const std::vector& ydims) { + std::vector rdims; + for (size_t i = 0; i < xdims.size(); i++) { + if (xdims[i] != ydims[i]) { + rdims.push_back(i); + } + 
} + return rdims; +} + +template +class ElementwiseAddGradXPUKPKernel : public ElemwiseGradKernel { + using XPUType = typename XPUTypeTrait::Type; + + public: + void Compute(const framework::ExecutionContext& ctx) const override { + ElemwiseGradKernel::Compute(ctx); + auto* x = ctx.Input("X"); + auto* y = ctx.Input("Y"); + auto* dz = ctx.Input(framework::GradVarName("Out")); + auto* dx = ctx.Output(framework::GradVarName("X")); + auto* dy = ctx.Output(framework::GradVarName("Y")); + const framework::DDim& x_dims = x->dims(); + const framework::DDim& y_dims = y->dims(); + const framework::DDim& dz_dims = dz->dims(); + int axis = ctx.Attr("axis"); + axis = (axis == -1 ? std::abs(x_dims.size() - y_dims.size()) : axis); + int max_dim = std::max(x_dims.size(), y_dims.size()); + PADDLE_ENFORCE_GE( + axis, 0, + platform::errors::InvalidArgument( + "Axis should be great than or equal to 0, but received axis is %d.", + axis)); + PADDLE_ENFORCE_LT( + axis, max_dim, + platform::errors::InvalidArgument( + "Axis should be less than %d, but received axis is %d.", max_dim, + axis)); + + std::vector x_dims_vec(max_dim, 1); + std::vector y_dims_vec(max_dim, 1); + std::vector z_dims_vec(max_dim, 1); + if (x_dims.size() == max_dim) { + for (int i = 0; i < max_dim; i++) { + x_dims_vec[i] = x_dims[i]; + } + } else { + for (int i = 0; i < x_dims.size(); i++) { + x_dims_vec[i + axis] = x_dims[i]; + } + } + + if (y_dims.size() == max_dim) { + for (int i = 0; i < max_dim; i++) { + y_dims_vec[i] = y_dims[i]; + } + } else { + for (int i = 0; i < y_dims.size(); i++) { + y_dims_vec[i + axis] = y_dims[i]; + } + } + + for (int i = 0; i < max_dim; i++) { + z_dims_vec[i] = dz_dims[i]; + } + std::vector rdims_for_x; + std::vector rdims_for_y; + rdims_for_x = get_rdims(x_dims_vec, z_dims_vec); + rdims_for_y = get_rdims(y_dims_vec, z_dims_vec); + const T* dz_data = dz->data(); + auto& dev_ctx = + ctx.template device_context(); + + if (dx != nullptr) { + T* dx_data = dx->mutable_data(ctx.GetPlace()); 
+ if (rdims_for_x.size() == 0) { + if (dx_data != dz_data) { + framework::TensorCopy( + *dz, ctx.GetPlace(), + ctx.template device_context(), dx); + } + } else { + // For inplace strategy, dx will be stored in addr of dz, which makes + // the result of dy wrong. Give dx its own buffer and refresh dx_data. + if (dx->IsSharedBufferWith(*dz)) { + dx->clear(); + dx_data = dx->mutable_data(x->dims(), ctx.GetPlace()); + } + + int ret = xpu::reduce_sum( + dev_ctx.x_context(), reinterpret_cast(dz_data), + reinterpret_cast(dx_data), z_dims_vec, rdims_for_x); + PADDLE_ENFORCE_XDNN_SUCCESS(ret, "reduce_sum "); + } + } + + if (dy != nullptr) { + T* dy_data = dy->mutable_data(ctx.GetPlace()); + if (rdims_for_y.size() == 0) { + if (dy_data != dz_data) { + framework::TensorCopy( + *dz, ctx.GetPlace(), + ctx.template device_context(), dy); + } + } else { + int ret = xpu::reduce_sum( + dev_ctx.x_context(), reinterpret_cast(dz_data), + reinterpret_cast(dy_data), z_dims_vec, rdims_for_y); + PADDLE_ENFORCE_XDNN_SUCCESS(ret, "reduce_sum "); + } + } + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +namespace plat = paddle::platform; + +REGISTER_OP_KERNEL(elementwise_add, KP, plat::XPUPlace, + ops::ElementwiseAddXPUKPKernel); + +REGISTER_OP_KERNEL(elementwise_add_grad, KP, plat::XPUPlace, + ops::ElementwiseAddGradXPUKPKernel); + +#endif // PADDLE_WITH_XPU_KP diff --git a/paddle/fluid/platform/device_context.h b/paddle/fluid/platform/device_context.h index 17288b354a2..e9124dfc1f8 100644 --- a/paddle/fluid/platform/device_context.h +++ b/paddle/fluid/platform/device_context.h @@ -74,7 +74,10 @@ limitations under the License. 
*/ #include "paddle/fluid/platform/device/device_ext.h" #include "paddle/fluid/platform/device/stream.h" + +#if !defined(PADDLE_WITH_XPU_KP) || defined(__xpu_on_host__) #include "unsupported/Eigen/CXX11/Tensor" +#endif namespace Eigen { struct DefaultDevice; diff --git a/paddle/phi/core/hostdevice.h b/paddle/phi/core/hostdevice.h index 08fe3125287..0869df14323 100644 --- a/paddle/phi/core/hostdevice.h +++ b/paddle/phi/core/hostdevice.h @@ -18,14 +18,14 @@ #include #endif -#ifdef __xpu_kp__ +#if defined(__xpu__) #include #include "xpu/kernel/cluster_header.h" #include "xpu/kernel/debug.h" #include "xpu/kernel/math.h" #endif -#if (defined(__CUDACC__) || defined(__HIPCC__) || defined(__xpu_kp__)) +#if (defined(__CUDACC__) || defined(__HIPCC__) || defined(__xpu__)) #define HOSTDEVICE __host__ __device__ #define DEVICE __device__ #define HOST __host__ diff --git a/paddle/phi/kernels/funcs/broadcast_function.h b/paddle/phi/kernels/funcs/broadcast_function.h index be57b8630f8..84a36b849af 100644 --- a/paddle/phi/kernels/funcs/broadcast_function.h +++ b/paddle/phi/kernels/funcs/broadcast_function.h @@ -16,7 +16,7 @@ limitations under the License. 
*/ #include "paddle/phi/kernels/funcs/elementwise_base.h" -#if defined(__NVCC__) || defined(__HIPCC__) +#if defined(__NVCC__) || defined(__HIPCC__) || defined(__xpu__) namespace kps = phi::kps; @@ -122,7 +122,7 @@ struct DimensionsTransform { explicit DimensionsTransform(const std::vector &ins, const phi::DDim &dims, int axis) { - const int N = max(static_cast(ins.size()), 2); + const int N = std::max(static_cast(ins.size()), 2); dim_size = dims.size(); out_dims = phi::vectorize(dims); in_dims.resize(N); @@ -183,7 +183,7 @@ struct DimensionsTransform { } }; -#if defined(__NVCC__) || defined(__HIPCC__) +#if defined(__NVCC__) || defined(__HIPCC__) || defined(__xpu__) template __device__ __forceinline__ void LoadData( @@ -268,7 +268,7 @@ __global__ void VectorizedBroadcastKernel( int block_offset = BLOCK_ID_X * BLOCK_NUM_X * VecSize; int stride = BLOCK_NUM_X * GRID_NUM_X * VecSize; -#ifdef PADDLE_WITH_XPU2 +#ifdef PADDLE_WITH_XPU_KP for (; block_offset < main_offset; block_offset += stride) { VectorizedBroadcastKernelImpl outs_data; for (int i = 0; i < NumOuts; ++i) { - outs_data[i] = ctx.Alloc((*outs)[i]); + outs_data[i] = (_ptr_ OutT *)(ctx.Alloc((*outs)[i])); } for (int i = 0; i < Arity; i++) { use_broadcast[i] = (ins[i]->numel() != numel); - ins_data[i] = (_ptr_ InT *)(ins[i]->data()); + ins_data[i] = (const _ptr_ InT *)(ins[i]->data()); if (use_broadcast[i]) { // get the broadcast config, // if data shape is[m, n], then you should set data_dim = {n, m} @@ -363,7 +363,7 @@ void LaunchBroadcastKernel(const KPDevice &ctx, } } -#ifdef PADDLE_WITH_XPU2 +#ifdef PADDLE_WITH_XPU_KP const int threads = 64; const int blocks = 8; int main_offset = (numel / (VecSize * threads)) * VecSize * threads; diff --git a/paddle/phi/kernels/funcs/eigen/extensions.h b/paddle/phi/kernels/funcs/eigen/extensions.h index 5fc8f76d988..fbb9d8e3d2e 100644 --- a/paddle/phi/kernels/funcs/eigen/extensions.h +++ b/paddle/phi/kernels/funcs/eigen/extensions.h @@ -14,6 +14,8 @@ #pragma once +#ifndef 
__xpu__ + #include "paddle/phi/common/bfloat16.h" #include "paddle/phi/common/complex.h" #include "paddle/phi/common/float16.h" @@ -435,3 +437,5 @@ HOSTDEVICE inline float16 maxi(const float16& a, const float16& b) { } // namespace numext } // namespace Eigen + +#endif // __xpu__ diff --git a/paddle/phi/kernels/funcs/elementwise_base.h b/paddle/phi/kernels/funcs/elementwise_base.h index 9a429dfaaf9..47f1593a11e 100644 --- a/paddle/phi/kernels/funcs/elementwise_base.h +++ b/paddle/phi/kernels/funcs/elementwise_base.h @@ -21,12 +21,13 @@ limitations under the License. */ #include "paddle/phi/kernels/empty_kernel.h" #include "paddle/phi/kernels/funcs/math_function.h" -#if defined(__NVCC__) || defined(__HIPCC__) +#if defined(__NVCC__) || defined(__HIPCC__) || defined(__xpu__) #include "paddle/fluid/platform/aligned_vector.h" #include "paddle/fluid/platform/function_traits.h" #include "paddle/phi/backends/gpu/gpu_launch_config.h" #include "paddle/phi/kernels/primitive/kernel_primitives.h" +#define HOSTDEVICE __host__ __device__ namespace kps = phi::kps; #endif @@ -436,7 +437,7 @@ inline void ElementwiseGradPreProcess(const DenseTensor &dout, } } -#if defined(__NVCC__) || defined(__HIPCC__) +#if defined(__NVCC__) || defined(__HIPCC__) || defined(__xpu__) // static unroller template