From 17d8a5e0c270206218891d6f41ffda3271f26c4a Mon Sep 17 00:00:00 2001 From: Feng Xing <79969986+xingfeng01@users.noreply.github.com> Date: Fri, 11 Mar 2022 18:20:46 +0800 Subject: [PATCH] Separate include and macro in kp top level file (#40202) * format softmax forward * seperate include and macro to two if-else --- .../phi/kernels/primitive/kernel_primitives.h | 31 +++++++++++++------ 1 file changed, 22 insertions(+), 9 deletions(-) diff --git a/paddle/phi/kernels/primitive/kernel_primitives.h b/paddle/phi/kernels/primitive/kernel_primitives.h index 830bc1972c..b5a1e88acc 100644 --- a/paddle/phi/kernels/primitive/kernel_primitives.h +++ b/paddle/phi/kernels/primitive/kernel_primitives.h @@ -13,7 +13,10 @@ // limitations under the License. #pragma once + #include "paddle/phi/kernels/primitive/helper_primitives.h" + +// macro #ifdef PADDLE_WITH_XPU_KP #define KPStream XPUStream @@ -22,11 +25,6 @@ #define __forceinline__ __inline__ #define __restrict__ -#include "paddle/phi/backends/xpu/xpu_context.h" -#include "paddle/phi/kernels/primitive/compute_primitives_xpu2.h" -#include "paddle/phi/kernels/primitive/datamover_primitives_xpu2.h" -#include "paddle/phi/kernels/primitive/functor_primitives_xpu2.h" - #define THREAD_ID_X core_id() #define THREAD_ID_Y 0 #define THREAD_ID_Z 0 @@ -42,11 +40,8 @@ #define GRID_NUM_X cluster_num() #define GRID_NUM_Y 0 #define GRID_NUM_Z 0 + #else -#include "paddle/phi/backends/gpu/gpu_context.h" -#include "paddle/phi/kernels/primitive/compute_primitives.h" -#include "paddle/phi/kernels/primitive/datamover_primitives.h" -#include "paddle/phi/kernels/primitive/functor_primitives.h" #define KPStream gpuStream_t #define KPDevice phi::GPUContext @@ -67,4 +62,22 @@ #define GRID_NUM_X gridDim.x #define GRID_NUM_Y gridDim.y #define GRID_NUM_Z gridDim.z + +#endif + +// include file +#ifdef PADDLE_WITH_XPU_KP + +#include "paddle/phi/backends/xpu/xpu_context.h" +#include "paddle/phi/kernels/primitive/compute_primitives_xpu2.h" +#include "paddle/phi/kernels/primitive/datamover_primitives_xpu2.h" +#include "paddle/phi/kernels/primitive/functor_primitives_xpu2.h" + +#else + +#include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/kernels/primitive/compute_primitives.h" +#include "paddle/phi/kernels/primitive/datamover_primitives.h" +#include "paddle/phi/kernels/primitive/functor_primitives.h" + #endif -- GitLab