未验证 提交 17d8a5e0 编写于 作者: F Feng Xing 提交者: GitHub

Separate include and macro in kp top level file (#40202)

* format softmax forward

* seperate include and macro to two if-else
上级 88c03071
...@@ -13,7 +13,10 @@ ...@@ -13,7 +13,10 @@
// limitations under the License. // limitations under the License.
#pragma once #pragma once
#include "paddle/phi/kernels/primitive/helper_primitives.h" #include "paddle/phi/kernels/primitive/helper_primitives.h"
// macro
#ifdef PADDLE_WITH_XPU_KP #ifdef PADDLE_WITH_XPU_KP
#define KPStream XPUStream #define KPStream XPUStream
...@@ -22,11 +25,6 @@ ...@@ -22,11 +25,6 @@
#define __forceinline__ __inline__ #define __forceinline__ __inline__
#define __restrict__ #define __restrict__
#include "paddle/phi/backends/xpu/xpu_context.h"
#include "paddle/phi/kernels/primitive/compute_primitives_xpu2.h"
#include "paddle/phi/kernels/primitive/datamover_primitives_xpu2.h"
#include "paddle/phi/kernels/primitive/functor_primitives_xpu2.h"
#define THREAD_ID_X core_id() #define THREAD_ID_X core_id()
#define THREAD_ID_Y 0 #define THREAD_ID_Y 0
#define THREAD_ID_Z 0 #define THREAD_ID_Z 0
...@@ -42,11 +40,8 @@ ...@@ -42,11 +40,8 @@
#define GRID_NUM_X cluster_num() #define GRID_NUM_X cluster_num()
#define GRID_NUM_Y 0 #define GRID_NUM_Y 0
#define GRID_NUM_Z 0 #define GRID_NUM_Z 0
#else #else
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/kernels/primitive/compute_primitives.h"
#include "paddle/phi/kernels/primitive/datamover_primitives.h"
#include "paddle/phi/kernels/primitive/functor_primitives.h"
#define KPStream gpuStream_t #define KPStream gpuStream_t
#define KPDevice phi::GPUContext #define KPDevice phi::GPUContext
...@@ -67,4 +62,22 @@ ...@@ -67,4 +62,22 @@
#define GRID_NUM_X gridDim.x #define GRID_NUM_X gridDim.x
#define GRID_NUM_Y gridDim.y #define GRID_NUM_Y gridDim.y
#define GRID_NUM_Z gridDim.z #define GRID_NUM_Z gridDim.z
#endif
// include file
#ifdef PADDLE_WITH_XPU_KP
#include "paddle/phi/backends/xpu/xpu_context.h"
#include "paddle/phi/kernels/primitive/compute_primitives_xpu2.h"
#include "paddle/phi/kernels/primitive/datamover_primitives_xpu2.h"
#include "paddle/phi/kernels/primitive/functor_primitives_xpu2.h"
#else
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/kernels/primitive/compute_primitives.h"
#include "paddle/phi/kernels/primitive/datamover_primitives.h"
#include "paddle/phi/kernels/primitive/functor_primitives.h"
#endif #endif
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册