未验证 提交 17d8a5e0 编写于 作者: F Feng Xing 提交者: GitHub

Separate include and macro in kp top level file (#40202)

* format softmax forward

* seperate include and macro to two if-else
上级 88c03071
......@@ -13,7 +13,10 @@
// limitations under the License.
#pragma once
#include "paddle/phi/kernels/primitive/helper_primitives.h"
// macro
#ifdef PADDLE_WITH_XPU_KP
#define KPStream XPUStream
......@@ -22,11 +25,6 @@
#define __forceinline__ __inline__
#define __restrict__
#include "paddle/phi/backends/xpu/xpu_context.h"
#include "paddle/phi/kernels/primitive/compute_primitives_xpu2.h"
#include "paddle/phi/kernels/primitive/datamover_primitives_xpu2.h"
#include "paddle/phi/kernels/primitive/functor_primitives_xpu2.h"
#define THREAD_ID_X core_id()
#define THREAD_ID_Y 0
#define THREAD_ID_Z 0
......@@ -42,11 +40,8 @@
#define GRID_NUM_X cluster_num()
#define GRID_NUM_Y 0
#define GRID_NUM_Z 0
#else
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/kernels/primitive/compute_primitives.h"
#include "paddle/phi/kernels/primitive/datamover_primitives.h"
#include "paddle/phi/kernels/primitive/functor_primitives.h"
#define KPStream gpuStream_t
#define KPDevice phi::GPUContext
......@@ -67,4 +62,22 @@
#define GRID_NUM_X gridDim.x
#define GRID_NUM_Y gridDim.y
#define GRID_NUM_Z gridDim.z
#endif
// include file
#ifdef PADDLE_WITH_XPU_KP
#include "paddle/phi/backends/xpu/xpu_context.h"
#include "paddle/phi/kernels/primitive/compute_primitives_xpu2.h"
#include "paddle/phi/kernels/primitive/datamover_primitives_xpu2.h"
#include "paddle/phi/kernels/primitive/functor_primitives_xpu2.h"
#else
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/kernels/primitive/compute_primitives.h"
#include "paddle/phi/kernels/primitive/datamover_primitives.h"
#include "paddle/phi/kernels/primitive/functor_primitives.h"
#endif
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册