diff --git a/paddle/phi/kernels/funcs/broadcast_function.h b/paddle/phi/kernels/funcs/broadcast_function.h index d0f3fba392a86e6555d4bc316833326bb99e7ec7..e754ce3bf49e4659f885e9d94a116bc98ef0aa26 100644 --- a/paddle/phi/kernels/funcs/broadcast_function.h +++ b/paddle/phi/kernels/funcs/broadcast_function.h @@ -110,10 +110,29 @@ struct BroadcastDataLoader { const Array3 &use_broadcast, const int block_offset, const int num, - const uint32_t numel) { + const uint32_t numel, + int read_lens) { using Type = std::tuple_element_t; +#ifdef PADDLE_WITH_XPU_KP + kps::Init( + args, static_cast(1.0f), read_lens); + if (use_broadcast[Index]) { + kps::ReadDataBc( + args, + reinterpret_cast(ins[Index]), + block_offset, + configs[Index], + numel, + read_lens); + } else { + kps::ReadData( + args, + reinterpret_cast(ins[Index]) + block_offset, + num, + read_lens); + } +#else kps::Init(args, static_cast(1.0f)); - if (use_broadcast[Index]) { kps::ReadDataBc( args, @@ -133,6 +152,7 @@ struct BroadcastDataLoader { num, VecSize); } +#endif } }; @@ -148,7 +168,8 @@ struct BroadcastDataLoader { const Array3 &use_broadcast, const int block_offset, const int num, - const uint32_t numel) { + const uint32_t numel, + int read_lens) { using Type = std::tuple_element_t; int thread_offset = threadIdx.x * VecSize + block_offset; #pragma unroll @@ -173,7 +194,8 @@ struct BroadcastDataLoader { const Array3 &use_broadcast, const int block_offset, const int num, - const uint32_t numel) { + const uint32_t numel, + int read_lens) { using Type = std::tuple_element_t; using VecType = phi::kps::details::VectorType; VecType vec_temp; @@ -269,6 +291,10 @@ __device__ void VectorizedBroadcastKernelImpl( __simd__ ArgsT args[VecSize]; __simd__ ConditionalT result[VecSize]; +#ifdef PADDLE_WITH_XPU_KP + BcUnroller::step( + ins, args, configs, use_broadcast, block_offset, num, numel, read_lens); +#else if (LoadType == kBroadcast) { uint32_t index_bc[Arity][VecSize] = {0}; Unroller::step(args); @@ -291,9 +317,9 @@ __device__ void VectorizedBroadcastKernelImpl( Unroller::step(ins, args, index_bc); } else { BcUnroller::step( - ins, args, configs, use_broadcast, block_offset, num, numel); + ins, args, configs, use_broadcast, block_offset, num, numel, read_lens); } - +#endif SameDimsElementwisePrimitiveCaller, VecSize, Functor, diff --git a/paddle/phi/kernels/primitive/datamover_primitives_xpu2.h b/paddle/phi/kernels/primitive/datamover_primitives_xpu2.h index 14a66516f5b632c2b94343645e04d7ae5165b983..7d60b573e2775c874e70e3011233cd0b87a11075 100644 --- a/paddle/phi/kernels/primitive/datamover_primitives_xpu2.h +++ b/paddle/phi/kernels/primitive/datamover_primitives_xpu2.h @@ -1211,6 +1211,65 @@ __device__ __inline__ void ReadDataBc(T* dst, } } +/** + * @brief Read 1D data from global memory to register with broadcast form. + * The difference from the above function is that it supports different data + * types of inputs. + * @template paraments + * T: The type of data stored in the global memory. + * NX: The number of data continuously loaded by each thread. + * NY: The number of data rows loaded by each thread, only NY = 1 was supported. + * core_id() is used as the index. + * IsBoundary: Indicates whether to perform block access storage out-of-bounds + * judgment. When the number of data processed by the block is less than + * NX x NY x core_num(), boundary judgment is required to avoid memory access + * crossing the boundary. + * + * @param: + * dst: The register pointer of the thread, the size is NX * NY. + * src: The original input data pointer of kernel. + * block_offset: The data offset of this block, core_num() * blockIdx.x * NX; + * config: Calculation configuration of broadcast. It is used to calculate the + * coordinate mapping relationship between output data and input data. + * read_lens: The number of data continuously loaded by each thread. + * total_num_output: Total number of original output. + */ +template +__device__ __forceinline__ void ReadDataBc( + ArgsT* dst, + const T _global_ptr_* src, + int block_offset, + const details::BroadcastConfig& config, + int total_num_output, + int read_lens = NX) { + int thread_offset = block_offset + core_id() * read_lens; + __local__ T in_temp[NX]; + + if (config.cmp_type == details::OptType::MNK_M1K) { + ReadDataBcM1kMnk(in_temp, src, thread_offset, config, read_lens); + } else if (config.cmp_type == details::OptType::N_1) { + ReadDataBc1N(in_temp, src, thread_offset, config, read_lens); + } else if (config.cmp_type == details::OptType::MN_M) { + ReadDataBcM1Mn(in_temp, src, thread_offset, config, read_lens); + } else if (config.cmp_type == details::OptType::MN_N) { + ReadDataBc1NMn(in_temp, src, thread_offset, config, read_lens); + } else if (config.cmp_type == details::OptType::MNK_1N1) { + ReadDataBc1N1Mnk(in_temp, src, thread_offset, config, read_lens); + } else { + ReadDataBcCanNotCmp( + in_temp, src, thread_offset, config, total_num_output, read_lens); + } +#pragma unroll + for (int idx = 0; idx < read_lens; ++idx) { + std::get(dst[idx]) = in_temp[idx]; + } +} + /** * @brief Initialize register with data index. *