未验证 提交 8d340ee1 编写于 作者: B Bo Zhang 提交者: GitHub

Fix xpu2 kp compile error (#53548)

上级 727fa27d
......@@ -110,10 +110,29 @@ struct BroadcastDataLoader {
const Array3 &use_broadcast,
const int block_offset,
const int num,
const uint32_t numel) {
const uint32_t numel,
int read_lens) {
using Type = std::tuple_element_t<Index, ArgsT>;
#ifdef PADDLE_WITH_XPU_KP
kps::Init<Type, ArgsT, Index, VecSize>(
args, static_cast<Type>(1.0f), read_lens);
if (use_broadcast[Index]) {
kps::ReadDataBc<Type, VecSize, 1, ArgsT, Index, IsBoundary>(
args,
reinterpret_cast<const _ptr_ Type *>(ins[Index]),
block_offset,
configs[Index],
numel,
read_lens);
} else {
kps::ReadData<Type, VecSize, 1, ArgsT, Index, IsBoundary>(
args,
reinterpret_cast<const _ptr_ Type *>(ins[Index]) + block_offset,
num,
read_lens);
}
#else
kps::Init<Type, ArgsT, Index, VecSize>(args, static_cast<Type>(1.0f));
if (use_broadcast[Index]) {
kps::ReadDataBc<Type, VecSize, 1, ArgsT, Index, IsBoundary>(
args,
......@@ -133,6 +152,7 @@ struct BroadcastDataLoader {
num,
VecSize);
}
#endif
}
};
......@@ -148,7 +168,8 @@ struct BroadcastDataLoader<Index, VecSize, true, kElementwise> {
const Array3 &use_broadcast,
const int block_offset,
const int num,
const uint32_t numel) {
const uint32_t numel,
int read_lens) {
using Type = std::tuple_element_t<Index, ArgsT>;
int thread_offset = threadIdx.x * VecSize + block_offset;
#pragma unroll
......@@ -173,7 +194,8 @@ struct BroadcastDataLoader<Index, VecSize, false, kElementwise> {
const Array3 &use_broadcast,
const int block_offset,
const int num,
const uint32_t numel) {
const uint32_t numel,
int read_lens) {
using Type = std::tuple_element_t<Index, ArgsT>;
using VecType = phi::kps::details::VectorType<Type, VecSize>;
VecType vec_temp;
......@@ -269,6 +291,10 @@ __device__ void VectorizedBroadcastKernelImpl(
__simd__ ArgsT args[VecSize];
__simd__ ConditionalT<OutT, NumOuts> result[VecSize];
#ifdef PADDLE_WITH_XPU_KP
BcUnroller<BroadcastDataLoader, IsBoundary, LoadType, VecSize, Arity>::step(
ins, args, configs, use_broadcast, block_offset, num, numel, read_lens);
#else
if (LoadType == kBroadcast) {
uint32_t index_bc[Arity][VecSize] = {0};
Unroller<BroadcastDataInit, VecSize, Arity>::step(args);
......@@ -291,9 +317,9 @@ __device__ void VectorizedBroadcastKernelImpl(
Unroller<BroadcastDataSetter, VecSize, Arity>::step(ins, args, index_bc);
} else {
BcUnroller<BroadcastDataLoader, IsBoundary, LoadType, VecSize, Arity>::step(
ins, args, configs, use_broadcast, block_offset, num, numel);
ins, args, configs, use_broadcast, block_offset, num, numel, read_lens);
}
#endif
SameDimsElementwisePrimitiveCaller<ConditionalT<OutT, NumOuts>,
VecSize,
Functor,
......
......@@ -1211,6 +1211,65 @@ __device__ __inline__ void ReadDataBc(T* dst,
}
}
/**
 * @brief Broadcast-aware 1D load from global memory into one field of an
 * array of argument tuples. In contrast to the ReadDataBc overload that
 * writes into a plain T* buffer, the destination here is an ArgsT array, so
 * inputs of different data types can be gathered into the same destination.
 *
 * The data is first read into a thread-local staging buffer of NX elements by
 * the reader that matches config.cmp_type, and the first read_lens staged
 * values are then assigned to std::get<Index>(dst[i]).
 *
 * @tparam T          The type of data stored in the global memory.
 * @tparam NX         Capacity of the local staging buffer; also the default
 *                    value of read_lens.
 * @tparam NY         The number of data rows loaded by each thread; only
 *                    NY = 1 is supported. It is not referenced in the body.
 * @tparam ArgsT      Tuple-like element type of dst, accessed through
 *                    std::get<Index>.
 * @tparam Index      Position inside ArgsT that receives the loaded values.
 * @tparam IsBoundary Whether out-of-bounds judgment is requested. It is only
 *                    forwarded to the generic fallback reader
 *                    (ReadDataBcCanNotCmp).
 *
 * @param dst              Destination array; field Index of elements
 *                         [0, read_lens) is overwritten.
 * @param src              The original input data pointer of the kernel
 *                         (global memory).
 * @param block_offset     The data offset of this block. The per-thread start
 *                         offset is block_offset + core_id() * read_lens.
 * @param config           Broadcast configuration; its cmp_type selects the
 *                         reader, and it is passed on to that reader.
 * @param total_num_output Total number of original output; only used by the
 *                         generic fallback reader.
 * @param read_lens        The number of data continuously loaded by each
 *                         thread. The staging buffer holds NX elements, so the
 *                         caller is expected to keep read_lens <= NX.
 */
template <typename T,
          int NX,
          int NY,
          typename ArgsT,
          int Index,
          bool IsBoundary = false>
__device__ __forceinline__ void ReadDataBc(
    ArgsT* dst,
    const T _global_ptr_* src,
    int block_offset,
    const details::BroadcastConfig& config,
    int total_num_output,
    int read_lens = NX) {
  // Thread-local staging area filled by the selected reader below.
  __local__ T staged[NX];
  int start = block_offset + core_id() * read_lens;

  // Pick the reader specialised for the compressed broadcast pattern; any
  // pattern without a dedicated reader goes through the generic fallback.
  switch (config.cmp_type) {
    case details::OptType::MNK_M1K:
      ReadDataBcM1kMnk<T>(staged, src, start, config, read_lens);
      break;
    case details::OptType::N_1:
      ReadDataBc1N<T>(staged, src, start, config, read_lens);
      break;
    case details::OptType::MN_M:
      ReadDataBcM1Mn<T>(staged, src, start, config, read_lens);
      break;
    case details::OptType::MN_N:
      ReadDataBc1NMn<T>(staged, src, start, config, read_lens);
      break;
    case details::OptType::MNK_1N1:
      ReadDataBc1N1Mnk<T>(staged, src, start, config, read_lens);
      break;
    default:
      ReadDataBcCanNotCmp<T, IsBoundary>(
          staged, src, start, config, total_num_output, read_lens);
      break;
  }

  // Scatter the staged values into the Index-th field of each tuple.
#pragma unroll
  for (int i = 0; i < read_lens; ++i) {
    std::get<Index>(dst[i]) = staged[i];
  }
}
/**
* @brief Initialize register with data index.
*
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册