diff --git a/paddle/phi/kernels/primitive/kernel_primitives.h b/paddle/phi/kernels/primitive/kernel_primitives.h index 830bc1972c49fe8c447e9a13f874841d36a12f2d..b5a1e88acc32b1b101a6f81b750be1c669236a1a 100644 --- a/paddle/phi/kernels/primitive/kernel_primitives.h +++ b/paddle/phi/kernels/primitive/kernel_primitives.h @@ -13,7 +13,10 @@ // limitations under the License. #pragma once + #include "paddle/phi/kernels/primitive/helper_primitives.h" + +// macro #ifdef PADDLE_WITH_XPU_KP #define KPStream XPUStream @@ -22,11 +25,6 @@ #define __forceinline__ __inline__ #define __restrict__ -#include "paddle/phi/backends/xpu/xpu_context.h" -#include "paddle/phi/kernels/primitive/compute_primitives_xpu2.h" -#include "paddle/phi/kernels/primitive/datamover_primitives_xpu2.h" -#include "paddle/phi/kernels/primitive/functor_primitives_xpu2.h" - #define THREAD_ID_X core_id() #define THREAD_ID_Y 0 #define THREAD_ID_Z 0 @@ -42,11 +40,8 @@ #define GRID_NUM_X cluster_num() #define GRID_NUM_Y 0 #define GRID_NUM_Z 0 + #else -#include "paddle/phi/backends/gpu/gpu_context.h" -#include "paddle/phi/kernels/primitive/compute_primitives.h" -#include "paddle/phi/kernels/primitive/datamover_primitives.h" -#include "paddle/phi/kernels/primitive/functor_primitives.h" #define KPStream gpuStream_t #define KPDevice phi::GPUContext @@ -67,4 +62,22 @@ #define GRID_NUM_X gridDim.x #define GRID_NUM_Y gridDim.y #define GRID_NUM_Z gridDim.z + +#endif + +// include file +#ifdef PADDLE_WITH_XPU_KP + +#include "paddle/phi/backends/xpu/xpu_context.h" +#include "paddle/phi/kernels/primitive/compute_primitives_xpu2.h" +#include "paddle/phi/kernels/primitive/datamover_primitives_xpu2.h" +#include "paddle/phi/kernels/primitive/functor_primitives_xpu2.h" + +#else + +#include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/kernels/primitive/compute_primitives.h" +#include "paddle/phi/kernels/primitive/datamover_primitives.h" +#include "paddle/phi/kernels/primitive/functor_primitives.h" + #endif