Unverified · Commit 3eaf8d2c authored by niuliling123, committed by GitHub

Modified Kernel Primitive API and elementwise for xpu2 #38688

Parent 2bed9b9c
@@ -25,8 +25,7 @@ namespace kps = paddle::operators::kernel_primitives;
 template <ElementwiseType ET, typename InT, typename OutT, typename Functor,
           int NumOuts = 1>
 void LaunchBroadcastElementwiseCudaKernel(
-    const platform::CUDADeviceContext &ctx,
-    const std::vector<const framework::Tensor *> &ins,
+    const KPDevice &ctx, const std::vector<const framework::Tensor *> &ins,
     std::vector<framework::Tensor *> *outs, int axis, Functor func) {
   std::vector<const pten::DenseTensor *> pt_inputs;
   std::vector<pten::DenseTensor *> pt_outputs;
@@ -58,8 +57,7 @@ void LaunchBroadcastElementwiseCudaKernel(
 template <ElementwiseType ET, typename InT, typename OutT, typename Functor,
           int NumOuts = 1>
 void LaunchElementwiseCudaKernel(
-    const platform::CUDADeviceContext &cuda_ctx,
-    const std::vector<const framework::Tensor *> &ins,
+    const KPDevice &ctx, const std::vector<const framework::Tensor *> &ins,
     std::vector<framework::Tensor *> *outs, int axis, Functor func) {
   std::vector<const pten::DenseTensor *> pt_inputs;
   std::vector<pten::DenseTensor *> pt_outputs;
@@ -85,7 +83,7 @@ void LaunchElementwiseCudaKernel(
     pt_outputs.push_back(pt_outputs_tmp[i].get());
   }
   pten::LaunchElementwiseCudaKernel<ET, InT, OutT, Functor, NumOuts>(
-      cuda_ctx, pt_inputs, &pt_outputs, axis, func);
+      ctx, pt_inputs, &pt_outputs, axis, func);
 }
 }  // namespace operators
...
@@ -35,8 +35,7 @@ using ElementwiseType = pten::ElementwiseType;
 template <ElementwiseType ET, typename InT, typename OutT, typename Functor,
           int NumOuts = 1>
 void LaunchSameDimsElementwiseCudaKernel(
-    const platform::CUDADeviceContext &ctx,
-    const std::vector<const framework::Tensor *> &ins,
+    const KPDevice &ctx, const std::vector<const framework::Tensor *> &ins,
     std::vector<framework::Tensor *> *outs, Functor func) {
   std::vector<const pten::DenseTensor *> pt_inputs;
   std::vector<pten::DenseTensor *> pt_outputs;
...
@@ -32,42 +32,50 @@ struct alignas(sizeof(T) * VecSize) VectorType {
 * index of the output data. if input or output shape is [dim0, dim1] then dims
 * must be [dim1, dim0].
 */
+#pragma pack(4)
 template <int kDims>
 struct BroadcastConfig {
-  uint32_t stride_in[framework::DDim::kMaxRank];
-  uint32_t stride_out[framework::DDim::kMaxRank];
-  uint32_t shape_in[framework::DDim::kMaxRank];
+  int strides_in[framework::DDim::kMaxRank];
+  int strides_out[framework::DDim::kMaxRank];
+  int in_dim[framework::DDim::kMaxRank];
   HOSTDEVICE BroadcastConfig() {}
   HOSTDEVICE BroadcastConfig(const std::vector<int64_t>& out_dims,
                              const std::vector<int64_t>& in_dims,
                              int dim_size) {
-    std::vector<uint32_t> strides_in;
-    std::vector<uint32_t> strides_out;
-    std::vector<uint32_t> shapes_in;
-
-    strides_out.resize(dim_size, 1);
-    strides_in.resize(dim_size, 1);
-    shapes_in.resize(dim_size, 1);
-
-    for (int i = 0; i < dim_size; ++i) {
-      shape_in[i] = in_dims[dim_size - i - 1];
-    }
-
-    for (int i = 1; i < dim_size - 1; ++i) {
-      strides_out[dim_size - i - 1] = std::accumulate(
-          out_dims.begin(), out_dims.begin() + i, 1, std::multiplies<int64_t>())
-      strides_in[dim_size - i - 1] =
-          std::accumulate(in_dims.begin(), in_dims.begin() + i, 1,
-                          std::multiplies<int64_t>())
-    }
-
-    memcpy(stride_in, strides_in.data(), kDims * sizeof(uint32_t));
-    memcpy(stride_out, strides_out.data(), kDims * sizeof(uint32_t));
-    memcpy(shape_in, shapes_in.data(), kDims * sizeof(uint32_t));
+    std::vector<int> strides_in_tmp;
+    std::vector<int> strides_out_tmp;
+    std::vector<int> dim_tmp;
+    strides_in_tmp.resize(dim_size, 1);
+    strides_out_tmp.resize(dim_size, 1);
+    dim_tmp.resize(dim_size, 1);
+    for (int i = 1; i < dim_size; i++) {
+      strides_in_tmp[i] = strides_in_tmp[i - 1] * in_dims[i - 1];
+      strides_out_tmp[i] = strides_out_tmp[i - 1] * out_dims[i - 1];
+    }
+
+    for (int i = 0; i < dim_size; i++) {
+      dim_tmp[i] = in_dims[i];
+    }
+
+    memcpy(strides_in, strides_in_tmp.data(), kDims * sizeof(int));
+    memcpy(strides_out, strides_out_tmp.data(), kDims * sizeof(int));
+    memcpy(in_dim, dim_tmp.data(), kDims * sizeof(int));
+  }
+
+  __device__ inline int operator()(int index_output) const {
+    int index_src = 0;
+#pragma unroll
+    for (int i = kDims - 1; i >= 0; --i) {
+      int tmp_index = (index_output / strides_out[i]);
+      index_output = index_output - tmp_index * strides_out[i];
+      index_src += (tmp_index % in_dim[i]) * strides_in[i];
+    }
+    return index_src;
+  }
 };
+#pragma pack()
 }  // namespace details
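A note on the new `operator()`: it maps a flat output index to the matching flat input index, with the `% in_dim[i]` term folding broadcast axes back to coordinate 0. Below is a minimal host-side sketch of the same arithmetic (standalone names, not part of the patch), assuming dims are passed fastest-varying first as the "[dim1, dim0]" comment above requires:

```cpp
#include <cstdio>
#include <vector>

// Standalone re-implementation of the BroadcastConfig index math for host
// experiments; BroadcastIndex is a hypothetical helper, not Paddle API.
int BroadcastIndex(int index_output, const std::vector<int>& in_dim,
                   const std::vector<int>& strides_in,
                   const std::vector<int>& strides_out) {
  int index_src = 0;
  for (int i = static_cast<int>(in_dim.size()) - 1; i >= 0; --i) {
    int coord = index_output / strides_out[i];          // coordinate along axis i
    index_output -= coord * strides_out[i];
    index_src += (coord % in_dim[i]) * strides_in[i];   // % folds broadcast axes
  }
  return index_src;
}

int main() {
  // Output shape [2, 3] vs. input shape [1, 3], both row-major and passed
  // reversed: in_dim = {3, 1}, strides_in = {1, 3}, strides_out = {1, 3}.
  std::vector<int> in_dim = {3, 1}, strides_in = {1, 3}, strides_out = {1, 3};
  for (int i = 0; i < 6; ++i) {
    printf("out %d -> in %d\n", i,
           BroadcastIndex(i, in_dim, strides_in, strides_out));
  }
  return 0;  // prints 0,1,2,0,1,2: both output rows read the single input row
}
```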
@@ -99,12 +107,12 @@ struct BroadcastConfig {
 */
 template <typename Tx, typename Ty, int NX, int NY, int BlockSize,
           bool IsBoundary = false>
-__device__ __forceinline__ void ReadData(Ty* dst, const Tx _global_ptr_* src,
-                                         int size_nx, int size_ny,
-                                         int stride_nx, int stride_ny) {
+__device__ __inline__ void ReadData(Ty* dst, const Tx _global_ptr_* src,
+                                    int size_nx, int size_ny, int stride_nx,
+                                    int stride_ny) {
   int thread_offset = core_id();
   int left_size_nx = size_nx - thread_offset;
-  __local__ T in_temp[1];
+  __local__ Tx in_temp[1];
   // Each branch is added for better performance
   if (NX == 1 && NY == 1) {  // for NX == 1 and NY == 1
     if (IsBoundary) {
@@ -168,7 +176,7 @@ __device__ __forceinline__ void ReadData(Ty* dst, const Tx _global_ptr_* src,
 * init_data: Initial value.
 */
 template <typename T, int NX>
-__device__ __forceinline__ void Init(T* dst, T init_data) {
+__device__ __inline__ void Init(T* dst, T init_data) {
 #pragma unroll
   for (int i = 0; i < NX; i++) {
     dst[i] = init_data;
@@ -197,8 +205,8 @@ __device__ __forceinline__ void Init(T* dst, T init_data) {
 * size: The current block needs to load size data continuously.
 */
 template <typename T, int NX, int NY, int BlockSize, bool IsBoundary = false>
-__device__ __forceinline__ void ReadData(T* dst, const T _global_ptr_* src,
-                                         int num) {
+__device__ __inline__ void ReadData(T* dst, const T _global_ptr_* src,
+                                    int num) {
   int thread_offset = core_id() * NX;
   __local__ T in_temp[1];
   if (IsBoundary) {  // core_num() * NX > num
@@ -241,10 +249,11 @@ __device__ __forceinline__ void ReadData(T* dst, const T _global_ptr_* src,
 */
 template <typename T, int NX, int NY, int BlockSize, int Rank,
           bool IsBoundary = false>
-__device__ __forceinline__ void ReadDataBc(
-    T* dst, const T _global_ptr_* src, uint32_t block_offset,
-    details::BroadcastConfig<Rank> config, int total_num_output, int stride_nx,
-    int stride_ny) {
+__device__ __inline__ void ReadDataBc(T* dst, const T _global_ptr_* src,
+                                      uint32_t block_offset,
+                                      details::BroadcastConfig<Rank> config,
+                                      int total_num_output, int stride_nx,
+                                      int stride_ny) {
   uint32_t thread_offset = block_offset + core_id();
   uint32_t index_src = 0;
   __local__ T in_temp[1];
@@ -256,16 +265,11 @@ __device__ __forceinline__ void ReadDataBc(
       uint32_t index_output = thread_offset + ny * stride_ny + nx * stride_nx;
       index_src = 0;
       if (IsBoundary) {
-        if (index_output >= total_num_output) {
+        if (index_output >= (uint32_t)total_num_output) {
           break;
         }
       }
-#pragma unroll
-      for (int i = 0; i < Rank; ++i) {
-        uint32_t tmp = index_output / config.stride_out[i];
-        index_output = index_output - tmp * config.stride_out[i];
-        index_src += (tmp % config.shape_in[i]) * config.stride_in[i];
-      }
+      index_src = config(index_output);
       GM2LM(src + index_src, in_temp, sizeof(T));
       dst[nx + ny * NX] = in_temp[0];
     }
@@ -305,33 +309,34 @@ __device__ __forceinline__ void ReadDataBc(
 */
 template <typename T, int NX, int NY, int BlockSize, int Rank,
           typename IndexCal, bool IsBoundary = false>
-__device__ __forceinline__ void ReadDataReduce(
-    T* dst, const T _global_ptr_* src, int block_offset,
-    const IndexCal& index_cal, int size_nx, int size_ny, int stride_nx,
-    int stride_ny, bool reduce_last_dim) {
-  __local__ T in_temp[1];
+__device__ __inline__ void ReadDataReduce(T* dst, const T _global_ptr_* src,
+                                          int block_offset,
+                                          const IndexCal& index_cal,
+                                          int size_nx, int size_ny,
+                                          int stride_nx, int stride_ny,
+                                          bool reduce_last_dim) {
+  __local__ Tx in_temp[1];
   int thread_offset = 0;
-  int left_size_nx = size_nx;
-  int left_size_ny = size_ny;
+  int left_idx = 0;
   if (reduce_last_dim) {
-    thread_offset = block_offset + core_id();
-    left_size_nx -= thread_offset;
+    thread_offset = core_id();
+    left_idx = 0;
   } else {
-    thread_offset = block_offset + core_id();
-    left_size_ny -= thread_offset;
+    thread_offset = 0;
+    left_idx = 0;
  }
   if (NX == 1) {
 #pragma unroll
     for (int ny = 0; ny < NY; ++ny) {
       if (IsBoundary) {
-        if (ny * stride_ny >= left_size_ny) {
+        if (thread_offset >= size_ny) {
           break;
         }
       }
-      uint32_t index_src = index_cal(thread_offset);
-      GM2LM(src + index_src, in_temp, sizeof(T));
-      dst[ny] = in_temp[0];
+      uint32_t index_src = index_cal(thread_offset + block_offset);
+      GM2LM(src + index_src, in_temp, sizeof(Tx));
+      dst[ny] = static_cast<Ty>(func(in_temp[0]));
       thread_offset += stride_ny;
     }
   } else {
@@ -340,17 +345,16 @@ __device__ __forceinline__ void ReadDataReduce(
 #pragma unroll
     for (int ny = 0; ny < NY; ++ny) {
       if (IsBoundary) {
-        if ((ny * stride_ny >= left_size_ny) ||
-            (nx * stride_nx >= left_size_nx)) {
+        if ((thread_offset >= size_ny) ||
+            (left_idx + nx * stride_nx >= size_nx)) {
           break;
         }
       }
-      uint32_t index_src = index_cal(thread_offset);
-      GM2LM(src + index_src, in_temp, sizeof(T));
-      dst[nx + ny * NX] = in_temp[0];
+      uint32_t index_src = index_cal(thread_offset + block_offset);
+      GM2LM(src + index_src, in_temp, sizeof(Tx));
+      dst[nx + ny * NX] = static_cast<Ty>(func(in_temp[0]));
       thread_offset += stride_ny;
     }
-    thread_offset += stride_nx;
   }
 }
@@ -421,9 +425,9 @@ __device__ void WriteData(T _global_ptr_* dst, const T* src, int num) {
 */
 template <typename Tx, typename Ty, int NX, int NY, int BlockSize,
           bool IsBoundary = false>
-__device__ __forceinline__ void WriteData(Ty _global_ptr_* dst, const Tx* src,
-                                          int size_nx, int size_ny,
-                                          int stride_nx, int stride_ny) {
+__device__ __inline__ void WriteData(Ty _global_ptr_* dst, const Tx* src,
+                                     int size_nx, int size_ny, int stride_nx,
+                                     int stride_ny) {
   int thread_offset = core_id();
   int left_size_nx = size_nx - thread_offset;
   __local__ Ty in_temp[1];
@@ -433,11 +437,11 @@ __device__ __forceinline__ void WriteData(Ty _global_ptr_* dst, const Tx* src,
     if (IsBoundary) {
       if (left_size_nx > 0) {
         in_temp[0] = static_cast<Ty>(src[0]);
-        LM2GM(in_temp, dst + thread_offset, sizeof(T));
+        LM2GM(in_temp, dst + thread_offset, sizeof(Ty));
       }
     } else {
       in_temp[0] = static_cast<Ty>(src[0]);
-      LM2GM(in_temp, dst + thread_offset, sizeof(T));
+      LM2GM(in_temp, dst + thread_offset, sizeof(Ty));
     }
   } else if (NX == 1) {
 #pragma unroll
@@ -449,7 +453,7 @@ __device__ __forceinline__ void WriteData(Ty _global_ptr_* dst, const Tx* src,
       }
       in_temp[0] = static_cast<Ty>(src[idy]);
-      LM2GM(in_temp, dst + thread_offset + idy * stride_ny, sizeof(T));
+      LM2GM(in_temp, dst + thread_offset + idy * stride_ny, sizeof(Ty));
     }
   } else if (NY == 1) {  // for NY == 1 and NX != 1
 #pragma unroll
@@ -461,7 +465,7 @@ __device__ __forceinline__ void WriteData(Ty _global_ptr_* dst, const Tx* src,
       }
       in_temp[0] = static_cast<Ty>(src[idx]);
-      LM2GM(in_temp, dst + thread_offset + idx * stride_nx, sizeof(T));
+      LM2GM(in_temp, dst + thread_offset + idx * stride_nx, sizeof(Ty));
     }
   } else {  // for NX != 1 and NY != 1
 #pragma unroll
@@ -480,7 +484,7 @@ __device__ __forceinline__ void WriteData(Ty _global_ptr_* dst, const Tx* src,
       }
       in_temp[0] = static_cast<Ty>(src[idx + idy * NX]);
       LM2GM(in_temp, dst + thread_offset + idx * stride_nx + idy * stride_ny,
-            sizeof(T));
+            sizeof(Ty));
     }
   }
 }
@@ -498,7 +502,7 @@ __device__ __forceinline__ void WriteData(Ty _global_ptr_* dst, const Tx* src,
 * init_data: The register pointer of init data, the size is NX.
 */
 template <typename T, int NX, bool IsBoundary = false>
-__device__ __forceinline__ void Init(T* dst, T* init_data, int num) {
+__device__ __inline__ void Init(T* dst, T* init_data, int num) {
 #pragma unroll
   for (int i = 0; i < NX; i++) {
     if (IsBoundary) {
@@ -535,30 +539,26 @@ __device__ __forceinline__ void Init(T* dst, T* init_data, int num) {
 */
 template <typename T, int NX, int NY, int BlockSize, int Rank,
           bool IsBoundary = false>
-__device__ __forceinline__ void ReadDataBc(
-    T* dst, const T _global_ptr_* src, uint32_t block_offset,
-    details::BroadcastConfig<Rank> config, int total_num_output) {
-  uint32_t thread_offset = block_offset + core_id() * NX;
-  uint32_t index_src = 0;
-  __local__ T in_temp[1];
+__device__ __inline__ void ReadDataBc(T* dst, const T _global_ptr_* src,
+                                      uint32_t block_offset,
+                                      details::BroadcastConfig<Rank> config,
+                                      int total_num_output) {
+  int thread_offset = block_offset + core_id() * NX;
+  int index_src = 0;
+  __local__ T in_temp;
 #pragma unroll
-  for (uint32_t nx = 0; nx < NX; ++nx) {
-    uint32_t index_output = thread_offset + nx;
+  for (int nx = 0; nx < NX; ++nx) {
+    int index_output = thread_offset + nx;
     index_src = 0;
     if (IsBoundary) {
       if (index_output >= total_num_output) {
         break;
       }
     }
-#pragma unroll
-    for (int i = 0; i < Rank; ++i) {
-      uint32_t tmp = index_output / config.stride_out[i];
-      index_output = index_output - tmp * config.stride_out[i];
-      index_src += (tmp % config.shape_in[i]) * config.stride_in[i];
-    }
-    GM2LM(src + index_src, in_temp, sizeof(T));
-    dst[nx + ny * NX] = in_temp[0];
+    index_src = config(index_output);
+    GM2LM(src + index_src, &in_temp, sizeof(T));
+    dst[nx] = in_temp;
  }
 }
...
@@ -13,11 +13,18 @@
 // limitations under the License.
 #pragma once
-#include "paddle/fluid/operators/kernel_primitives/functor_primitives.h"
 #include "paddle/fluid/operators/kernel_primitives/helper_primitives.h"
 #ifdef PADDLE_WITH_XPU2
 #include "paddle/fluid/operators/kernel_primitives/compute_primitives_xpu2.h"
 #include "paddle/fluid/operators/kernel_primitives/datamover_primitives_xpu2.h"
+#include "paddle/fluid/operators/kernel_primitives/functor_primitives_xpu2.h"
+
+#define KPStream XPUStream
+#define KPDevice paddle::platform::XPUDeviceContext
+#define _ptr_ _global_ptr_
+#define __forceinline__ __inline__
+#define __restrict__
+
 #define THREAD_ID_X core_id()
 #define THREAD_ID_Y 0
 #define THREAD_ID_Z 0
@@ -36,6 +43,12 @@
 #else
 #include "paddle/fluid/operators/kernel_primitives/compute_primitives.h"
 #include "paddle/fluid/operators/kernel_primitives/datamover_primitives.h"
+#include "paddle/fluid/operators/kernel_primitives/functor_primitives.h"
+
+#define KPStream gpuStream_t
+#define KPDevice paddle::platform::CUDADeviceContext
+#define _ptr_
+
 #define THREAD_ID_X threadIdx.x
 #define THREAD_ID_Y threadIdx.y
 #define THREAD_ID_Z threadIdx.z
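The `KPStream`/`KPDevice` aliases above are what let the launch functions elsewhere in this commit take a single `const KPDevice &ctx` parameter and compile for either backend. A self-contained sketch of the same aliasing pattern, with stand-in context types (`FakeXpuContext`/`FakeCudaContext` are illustrative, not Paddle types):

```cpp
#include <cstdio>

// Stand-ins for the real device contexts (illustrative only).
struct FakeXpuContext  { int xpu_stream = 1; };
struct FakeCudaContext { int stream() const { return 2; } };

#ifdef PADDLE_WITH_XPU2
using DemoKPDevice = FakeXpuContext;
int GetStream(const DemoKPDevice& ctx) { return ctx.xpu_stream; }
#else
using DemoKPDevice = FakeCudaContext;
int GetStream(const DemoKPDevice& ctx) { return ctx.stream(); }
#endif

int main() {
  DemoKPDevice ctx;
  // One call site compiles against either backend's context type.
  printf("stream handle: %d\n", GetStream(ctx));
  return 0;
}
```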
...
@@ -17,7 +17,14 @@
 #include <hip/hip_runtime.h>
 #endif
+#ifdef __xpu_kp__
+#include <xpu/runtime.h>
+#include "xpu/kernel/cluster_header.h"
+#include "xpu/kernel/debug.h"
+#include "xpu/kernel/math.h"
+#endif
+
-#if (defined(__CUDACC__) || defined(__HIPCC__))
+#if (defined(__CUDACC__) || defined(__HIPCC__) || defined(__xpu_kp__))
 #define HOSTDEVICE __host__ __device__
 #define DEVICE __device__
 #define HOST __host__
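With the extra `__xpu_kp__` clause, a HOSTDEVICE-qualified helper compiles unchanged under NVCC, HIP, and the XPU KP compiler. A tiny self-contained sketch (the fallback define below stands in for the header's plain-host branch; `RoundUpTo` is a hypothetical helper):

```cpp
#ifndef HOSTDEVICE
#define HOSTDEVICE  // plain host build: the qualifier expands to nothing
#endif
#include <cstdio>

// One definition, callable from host code and from device kernels.
HOSTDEVICE inline int RoundUpTo(int x, int align) {
  return (x + align - 1) / align * align;
}

int main() {
  printf("%d\n", RoundUpTo(1000, 256));  // 1024
  return 0;
}
```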
...
@@ -86,7 +86,7 @@ struct ElementwisePrimitiveCaller<InT, OutT, VecSize, Functor, 3, false> {
 template <typename OutT, int VecSize, bool IsBoundary, int NumOuts>
 struct ElementwiseWriteDataCaller {
   __device__ __forceinline__ void operator()(
-      paddle::framework::Array<OutT *, NumOuts> outs,
+      paddle::framework::Array<_ptr_ OutT *, NumOuts> outs,
       ConditionalT<OutT, NumOuts> src[VecSize],
       int block_offset,
       int num) {
@@ -109,7 +109,7 @@ struct ElementwiseWriteDataCaller {
 template <typename OutT, int VecSize, bool IsBoundary>
 struct ElementwiseWriteDataCaller<OutT, VecSize, IsBoundary, 1> {
   __device__ __forceinline__ void operator()(
-      paddle::framework::Array<OutT *, 1> outs,
+      paddle::framework::Array<_ptr_ OutT *, 1> outs,
       OutT src[VecSize],
       int block_offset,
       int num) {
@@ -126,8 +126,8 @@ template <typename InT,
           int VecSize,
           bool IsBoundary>
 __device__ void VectorizedElementwiseKernelImpl(
-    const paddle::framework::Array<const InT *__restrict__, Arity> &in,
-    paddle::framework::Array<OutT *, NumOuts> outs,
+    const paddle::framework::Array<const _ptr_ InT *__restrict__, Arity> &in,
+    paddle::framework::Array<_ptr_ OutT *, NumOuts> outs,
     int num,
     int data_offset,
     Functor func) {
@@ -161,8 +161,8 @@ template <typename InT,
           int NumOuts,
           int VecSize>
 __global__ void VectorizedElementwiseKernel(
-    paddle::framework::Array<const InT *__restrict__, Arity> ins,
-    paddle::framework::Array<OutT *, NumOuts> outs,
+    paddle::framework::Array<const _ptr_ InT *__restrict__, Arity> ins,
+    paddle::framework::Array<_ptr_ OutT *, NumOuts> outs,
     int size,
     int main_offset,
     Functor func) {
@@ -212,17 +212,13 @@ template <typename InT,
           int Arity,
           int NumOuts,
           int VecSize>
-void ElementwiseCudaKernel(const paddle::platform::CUDADeviceContext &ctx,
+void ElementwiseCudaKernel(const KPDevice &ctx,
                            const std::vector<const DenseTensor *> &ins,
                            std::vector<DenseTensor *> *outs,
                            Functor func) {
   auto numel = ins[0]->numel();
-  int block_size = funcs::GetThreadsConfig(ctx, numel, VecSize);
-  int grid_size =
-      ((numel + VecSize - 1) / VecSize + block_size - 1) / block_size;
-  auto stream = ctx.stream();
-  paddle::framework::Array<const InT *__restrict__, Arity> ins_data;
-  paddle::framework::Array<OutT *, NumOuts> outs_data;
+  paddle::framework::Array<const _ptr_ InT *__restrict__, Arity> ins_data;
+  paddle::framework::Array<_ptr_ OutT *, NumOuts> outs_data;
   for (int i = 0; i < Arity; ++i) {
     ins_data[i] = ins[i]->data<InT>();
@@ -231,8 +227,9 @@ void ElementwiseCudaKernel(const paddle::platform::CUDADeviceContext &ctx,
     outs_data[i] = (*outs)[i]->mutable_data<OutT>();
   }
 #ifdef PADDLE_WITH_XPU2
-  block_size = 128;
-  grid_size = 8;
+  int block_size = 64;
+  int grid_size = 8;
+  auto stream = ctx.x_context()->xpu_stream;
   int main_offset = (numel / (VecSize * block_size)) * VecSize * block_size;
   VectorizedElementwiseKernel<InT,
                               OutT,
@@ -242,7 +239,11 @@ void ElementwiseCudaKernel(const paddle::platform::CUDADeviceContext &ctx,
                               VecSize><<<grid_size, block_size, 0, stream>>>(
       ins_data, outs_data, numel, main_offset, func);
 #else
+  int block_size = funcs::GetThreadsConfig(ctx, numel, VecSize);
+  int grid_size =
+      ((numel + VecSize - 1) / VecSize + block_size - 1) / block_size;
   int main_offset = (numel / (VecSize * block_size)) * VecSize * block_size;
+  auto stream = ctx.stream();
   VectorizedElementwiseKernel<InT,
                               OutT,
                               Functor,
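For reference, the `main_offset` computed in this hunk splits the element range into a vectorized body and a boundary tail handled by the `IsBoundary = true` path. A worked example with the XPU2 configuration above (block_size = 64, VecSize = 4; numel = 1000 is an arbitrary illustration):

```cpp
#include <cstdio>

int main() {
  const int numel = 1000, VecSize = 4, block_size = 64;
  // Each block pass covers VecSize * block_size = 256 elements.
  int main_offset = (numel / (VecSize * block_size)) * VecSize * block_size;
  int tail = numel % (VecSize * block_size);
  printf("main_offset=%d tail=%d\n", main_offset, tail);  // 768 232
  return 0;  // elements [768, 1000) go through the boundary path
}
```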
@@ -259,7 +260,7 @@ template <ElementwiseType ET,
           typename Functor,
           int NumOuts = 1>
 void LaunchSameDimsElementwiseCudaKernel(
-    const paddle::platform::CUDADeviceContext &ctx,
+    const KPDevice &ctx,
     const std::vector<const DenseTensor *> &ins,
     std::vector<DenseTensor *> *outs,
     Functor func) {
@@ -471,12 +472,12 @@ struct DimensionsTransform {
 template <typename T, int VecSize, int Rank, bool IsBoundary = false>
 __device__ __forceinline__ void LoadData(
     T *dst,
-    const T *__restrict__ src,
+    const _ptr_ T *src,
     uint32_t block_offset,
     const kps::details::BroadcastConfig<Rank> &config,
     int numel,
     int num,
-    bool need_broadcast) {
+    int need_broadcast) {
   // numel : whole num of output
   // num: how many data will be deal with in this time
   if (need_broadcast) {
@@ -496,9 +497,9 @@ template <typename InT,
           int Rank,
           bool IsBoundary = false>
 __device__ void ElementwiseBroadcastKernelImpl(
-    const paddle::framework::Array<const InT *__restrict__, Arity> &ins,
-    paddle::framework::Array<OutT *, NumOuts> outs,
-    const paddle::framework::Array<bool, Arity> &use_broadcast,
+    const paddle::framework::Array<const _ptr_ InT *__restrict__, Arity> &ins,
+    paddle::framework::Array<_ptr_ OutT *, NumOuts> outs,
+    const paddle::framework::Array<int, Arity> &use_broadcast,
     uint32_t numel,
     const paddle::framework::Array<kps::details::BroadcastConfig<Rank>, Arity>
         &configs,
@@ -540,9 +541,9 @@ template <typename InT,
           int VecSize,
           int Rank>
 __global__ void ElementwiseBroadcastKernel(
-    paddle::framework::Array<const InT *__restrict__, Arity> ins,
-    paddle::framework::Array<OutT *, NumOuts> outs,
-    paddle::framework::Array<bool, Arity> use_broadcast,
+    paddle::framework::Array<const _ptr_ InT *__restrict__, Arity> ins,
+    paddle::framework::Array<_ptr_ OutT *, NumOuts> outs,
+    paddle::framework::Array<int, Arity> use_broadcast,
     uint32_t numel,
     paddle::framework::Array<kps::details::BroadcastConfig<Rank>, Arity>
         configs,
@@ -570,7 +571,8 @@ __global__ void ElementwiseBroadcastKernel(
                                    block_offset,
                                    func);
   }
-  if (block_offset < numel) {
+  int num = numel - block_offset;
+  if (num > 0) {
     ElementwiseBroadcastKernelImpl<InT,
                                    OutT,
                                    Functor,
@@ -579,7 +581,7 @@ __global__ void ElementwiseBroadcastKernel(
                                    VecSize,
                                    Rank,
                                    true>(
-        ins, outs, use_broadcast, numel, configs, tail_tid, block_offset, func);
+        ins, outs, use_broadcast, numel, configs, num, block_offset, func);
   }
 #else
   if (block_offset < main_offset) {
@@ -619,23 +621,16 @@ template <typename InT,
           int NumOuts,
           int VecSize,
           int Rank>
-void LaunchKernel(const paddle::platform::CUDADeviceContext &ctx,
+void LaunchKernel(const KPDevice &ctx,
                   const std::vector<const DenseTensor *> &ins,
                   std::vector<DenseTensor *> *outs,
                   Functor func,
                   DimensionsTransform merge_dims) {
   int numel = (*outs)[0]->numel();
-  const int threads = 256;
-  int blocks = ((numel + VecSize - 1) / VecSize + threads - 1) / threads;
-  int main_offset = (numel / (VecSize * threads)) * VecSize * threads;
-  int tail_tid = numel % (VecSize * threads);
-  auto stream = ctx.stream();
   paddle::framework::Array<kps::details::BroadcastConfig<Rank>, Arity> configs;
-  paddle::framework::Array<bool, Arity> use_broadcast;
-  paddle::framework::Array<const InT *__restrict__, Arity> ins_data;
-  paddle::framework::Array<OutT *, NumOuts> outs_data;
+  paddle::framework::Array<int, Arity> use_broadcast;
+  paddle::framework::Array<const _ptr_ InT *__restrict__, Arity> ins_data;
+  paddle::framework::Array<_ptr_ OutT *, NumOuts> outs_data;
   for (int i = 0; i < NumOuts; ++i) {
     outs_data[i] = (*outs)[i]->mutable_data<OutT>();
@@ -643,7 +638,7 @@ void LaunchKernel(const paddle::platform::CUDADeviceContext &ctx,
   for (int i = 0; i < Arity; i++) {
     use_broadcast[i] = (ins[i]->numel() != numel);
-    ins_data[i] = ins[i]->data<InT>();
+    ins_data[i] = (_ptr_ InT *)(ins[i]->data<InT>());
     if (use_broadcast[i]) {
       // get the broadcast config,
       // if data shape is[m, n], then you should set data_dim = {n, m}
@@ -654,10 +649,11 @@ void LaunchKernel(const paddle::platform::CUDADeviceContext &ctx,
   }
 #ifdef PADDLE_WITH_XPU2
-  threads = 128;
-  blocks = 8;
-  main_offset = (numel / (VecSize * threads)) * VecSize * threads;
-  tail_tid = numel % (VecSize * threads);
+  const int threads = 64;
+  const int blocks = 8;
+  int main_offset = (numel / (VecSize * threads)) * VecSize * threads;
+  int tail_tid = numel % (VecSize * threads);
+  auto stream = ctx.x_context()->xpu_stream;
   ElementwiseBroadcastKernel<InT,
                              OutT,
                              Functor,
@@ -673,6 +669,11 @@ void LaunchKernel(const paddle::platform::CUDADeviceContext &ctx,
                              tail_tid,
                              func);
 #else
+  const int threads = 256;
+  int blocks = ((numel + VecSize - 1) / VecSize + threads - 1) / threads;
+  int main_offset = (numel / (VecSize * threads)) * VecSize * threads;
+  int tail_tid = numel % (VecSize * threads);
+  auto stream = ctx.stream();
   ElementwiseBroadcastKernel<InT,
                              OutT,
                              Functor,
@@ -698,7 +699,7 @@ template <typename InT,
           int NumOuts,
           int VecSize>
 void LaunchBroadcastKernelForDifferentVecSize(
-    const paddle::platform::CUDADeviceContext &ctx,
+    const KPDevice &ctx,
     const std::vector<const DenseTensor *> &ins,
     std::vector<DenseTensor *> *outs,
     int axis,
@@ -737,7 +738,7 @@ template <ElementwiseType ET,
           typename Functor,
           int NumOuts = 1>
 void LaunchBroadcastElementwiseCudaKernel(
-    const paddle::platform::CUDADeviceContext &ctx,
+    const KPDevice &ctx,
     const std::vector<const DenseTensor *> &ins,
     std::vector<DenseTensor *> *outs,
     int axis,
@@ -835,12 +836,11 @@ template <ElementwiseType ET,
           typename OutT,
           typename Functor,
           int NumOuts = 1>
-void LaunchElementwiseCudaKernel(
-    const paddle::platform::CUDADeviceContext &cuda_ctx,
-    const std::vector<const DenseTensor *> &ins,
-    std::vector<DenseTensor *> *outs,
-    int axis,
-    Functor func) {
+void LaunchElementwiseCudaKernel(const KPDevice &ctx,
+                                 const std::vector<const DenseTensor *> &ins,
+                                 std::vector<DenseTensor *> *outs,
+                                 int axis,
+                                 Functor func) {
   std::vector<int> dims_size;
   bool no_broadcast_flag = true;
   for (auto *in : ins) {
@@ -849,14 +849,14 @@ void LaunchElementwiseCudaKernel(
   }
   if (no_broadcast_flag) {
     LaunchSameDimsElementwiseCudaKernel<ET, InT, OutT, Functor, NumOuts>(
-        cuda_ctx, ins, outs, func);
+        ctx, ins, outs, func);
   } else {
     axis = axis == -1
               ? *std::max_element(dims_size.begin(), dims_size.end()) -
                     *std::min_element(dims_size.begin(), dims_size.end())
               : axis;
     LaunchBroadcastElementwiseCudaKernel<ET, InT, OutT, Functor, NumOuts>(
-        cuda_ctx, ins, outs, axis, func);
+        ctx, ins, outs, axis, func);
   }
 }
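The `axis == -1` fallback above infers the alignment axis from the rank difference of the inputs. A worked example (the shapes are illustrative): broadcasting a rank-3 tensor [2, 3, 4] with a rank-2 tensor [3, 4] gives axis = 3 - 2 = 1, so the lower-rank operand is aligned starting at dimension 1 of the other.

```cpp
#include <algorithm>
#include <cstdio>
#include <vector>

int main() {
  std::vector<int> dims_size = {3, 2};  // ranks of the two inputs
  int axis = *std::max_element(dims_size.begin(), dims_size.end()) -
             *std::min_element(dims_size.begin(), dims_size.end());
  printf("inferred axis = %d\n", axis);  // 1
  return 0;
}
```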
...