Unverified commit 9c5d5665, authored by niuliling123, committed by GitHub

Modify reduce_op.op.h for xpu2 with kernel primitive api (#36904)

* Modify reduce_op.op.h for xpu2 with kernel primitive api

Parent: d08753df
@@ -360,12 +360,12 @@ __device__ __forceinline__ void ReadDataBc(
  * reduce_last_dim: Used to indicate whether the dimension of reduce contains
  * the lowest dimension.
  */
-template <typename T, int NX, int NY, int BlockSize, int Rank,
-          typename IndexCal, bool IsBoundary = false>
+template <typename Tx, typename Ty, int NX, int NY, int BlockSize, int Rank,
+          typename IndexCal, typename Functor, bool IsBoundary = false>
 __device__ __forceinline__ void ReadDataReduce(
-    T* dst, const T* __restrict__ src, int block_offset,
+    Ty* dst, const Tx* __restrict__ src, int block_offset,
     const IndexCal& index_cal, int size_nx, int size_ny, int stride_nx,
-    int stride_ny, bool reduce_last_dim) {
+    int stride_ny, Functor func, bool reduce_last_dim) {
   int thread_offset = 0;
   int left_idx = 0;
   if (reduce_last_dim) {
@@ -385,7 +385,7 @@ __device__ __forceinline__ void ReadDataReduce(
         }
       }
       uint32_t index_src = index_cal(thread_offset + block_offset);
-      dst[ny] = src[index_src];
+      dst[ny] = static_cast<Ty>(func(src[index_src]));
       thread_offset += stride_ny;
     }
   } else {
@@ -400,7 +400,7 @@ __device__ __forceinline__ void ReadDataReduce(
         }
       }
       uint32_t index_src = index_cal(thread_offset + block_offset);
-      dst[nx + ny * NX] = src[index_src];
+      dst[nx + ny * NX] = static_cast<Ty>(func(src[index_src]));
       thread_offset += stride_ny;
     }
   }
......
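The change above threads a transform functor through the reduce read path, so each element is converted from Tx to Ty as it is loaded rather than in a separate pass over memory. A minimal host-side C++ sketch of the same transform-on-load pattern; the DivideBy functor and the sizes are hypothetical illustrations, not part of the commit:

    #include <cstdio>

    // Hypothetical stand-in for the commit's Functor parameter: converts a raw
    // float element into a double mean contribution while it is being read.
    struct DivideBy {
      double inv;
      explicit DivideBy(int n) : inv(1.0 / n) {}
      double operator()(float x) const { return x * inv; }
    };

    // Simplified analogue of ReadDataReduce: gather `ny` strided elements from
    // src, applying `func` during the load, exactly once per element.
    template <typename Tx, typename Ty, typename Functor>
    void ReadAndTransform(Ty* dst, const Tx* src, int ny, int stride,
                          Functor func) {
      int offset = 0;
      for (int i = 0; i < ny; ++i) {
        dst[i] = static_cast<Ty>(func(src[offset]));
        offset += stride;
      }
    }

    int main() {
      float src[8] = {1, 2, 3, 4, 5, 6, 7, 8};
      double dst[4];
      ReadAndTransform(dst, src, 4, 2, DivideBy(4));  // loads src[0,2,4,6] / 4
      for (double v : dst) std::printf("%g ", v);     // prints: 0.25 0.75 1.25 1.75
    }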
@@ -17,64 +17,49 @@
 namespace paddle {
 namespace operators {
 namespace kernel_primitives {
-namespace details {
-
-static __device__ __forceinline__ platform::float16 ExpFunctor(
-    platform::float16 x) {
-  return ::Eigen::numext::exp(x);
-}
-static __device__ __forceinline__ float ExpFunctor(float x) { return expf(x); }
-static __device__ __forceinline__ double ExpFunctor(double x) { return exp(x); }
-static __device__ __forceinline__ platform::float16 LogFunctor(
-    platform::float16 x) {
-  return ::Eigen::numext::log(x);
-}
-static __device__ __forceinline__ float LogFunctor(float x) { return logf(x); }
-static __device__ __forceinline__ double LogFunctor(double x) { return log(x); }
-
-/*************************** Compute Functor****************************/
-// for margin_cross_entropy
-template <typename Tx, typename Ty = Tx>
-struct ExpLogitTransformer {
-  HOSTDEVICE explicit inline ExpLogitTransformer(int n) {}
-
-  HOSTDEVICE inline Ty operator()(const Tx* x) const {
-    return static_cast<Ty>(details::ExpFunctor(x[0]));
-  }
-
-  HOSTDEVICE inline Ty operator()(const Tx& x) const {
-    return static_cast<Ty>(details::ExpFunctor(x));
-  }
-};
+#ifdef PADDLE_WITH_XPU2
+struct dim3 {
+  int x;
+  int y;
+  int z;
 
-// Post processing function for sum, max, min, prod, any
-template <typename Tx, typename Ty = Tx>
-struct IdentityFunctor {
-  HOSTDEVICE explicit inline IdentityFunctor(int n) {}
-
-  HOSTDEVICE inline Ty operator()(const Tx* x) const {
-    return static_cast<Ty>(x[0]);
-  }
-
-  HOSTDEVICE inline Ty operator()(const Tx& x) const {
-    return static_cast<Ty>(x);
-  }
-};
+  explicit inline dim3(int split_x, int split_y = 1, int split_z = 1) {
+    x = split_x;
+    y = split_y;
+    z = split_z;
+  }
+};
+#endif
 
-// Post processing function for mean
-template <typename T>
-struct DivideFunctor {
-  HOSTDEVICE explicit inline DivideFunctor(int n) : n_inv((T)(1.0 / n)) {}
-
-  HOSTDEVICE inline T operator()(const T* x) const { return x[0] * n_inv; }
-
-  HOSTDEVICE inline T operator()(const T& x) const { return x * n_inv; }
-
- private:
-  T n_inv;
-};
-
-}  // namespace details
+struct DimConfig {
+  int split_num_x;
+  int split_num_y;
+  int split_num_z;
+  int deal_size_x;
+  int deal_size_y;
+  int deal_size_z;
+  int rem_x;
+  int rem_y;
+  int rem_z;
+
+  HOSTDEVICE explicit inline DimConfig(int split_x, int split_y, int split_z,
+                                       int size_x, int size_y, int size_z) {
+    split_num_x = split_x;
+    split_num_y = split_y;
+    split_num_z = split_z;
+    deal_size_x = size_x;
+    deal_size_y = size_y;
+    deal_size_z = size_z;
+  }
+
+  HOSTDEVICE void SetRem(int rem_nx, int rem_ny, int rem_nz) {
+    rem_x = rem_nx;
+    rem_y = rem_ny;
+    rem_z = rem_nz;
+  }
+};
+
 }  // namespace kernel_primitives
 }  // namespace operators
 }  // namespace paddle
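Read in isolation, DimConfig is a plain launch descriptor: split_num_* appears to say how many blocks each axis is cut into, deal_size_* how many elements one block handles per axis, and SetRem records the leftover elements that do not fill a whole block. A hypothetical setup for a 1000 x 64 problem; the numbers are illustrative, since the commit only defines the struct:

    // Hypothetical: tile a 1000 x 64 problem into 8 x 8 blocks of 128 x 8
    // elements, then record the tail work that does not fill a full block.
    kernel_primitives::DimConfig dim(/*split_x=*/8, /*split_y=*/8, /*split_z=*/1,
                                     /*size_x=*/128, /*size_y=*/8, /*size_z=*/1);
    dim.SetRem(1000 % 128, 64 % 8, 0);  // rem_x = 104, rem_y = 0, rem_z = 0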
@@ -13,11 +13,45 @@
 // limitations under the License.
 
 #pragma once
+#include "paddle/fluid/operators/kernel_primitives/functor_primitives.h"
+#include "paddle/fluid/operators/kernel_primitives/helper_primitives.h"
+#ifdef PADDLE_WITH_XPU2
+#include "paddle/fluid/operators/kernel_primitives/compute_primitives_xpu2.h"
+#include "paddle/fluid/operators/kernel_primitives/datamover_primitives_xpu2.h"
+
+#define THREAD_ID_X core_id()
+#define THREAD_ID_Y 0
+#define THREAD_ID_Z 0
+
+#define BLOCK_NUM_X core_num()
+#define BLOCK_NUM_Y 0
+#define BLOCK_NUM_Z 0
+
+#define BLOCK_ID_X cluster_id()
+#define BLOCK_ID_Y 0
+#define BLOCK_ID_Z 0
+
+#define GRID_NUM_X cluster_num()
+#define GRID_NUM_Y 0
+#define GRID_NUM_Z 0
+#else
 #include "paddle/fluid/operators/kernel_primitives/compute_primitives.h"
 #include "paddle/fluid/operators/kernel_primitives/datamover_primitives.h"
-#include "paddle/fluid/operators/kernel_primitives/functor_primitives.h"
-#include "paddle/fluid/operators/kernel_primitives/helper_primitives.h"
+
+#define THREAD_ID_X threadIdx.x
+#define THREAD_ID_Y threadIdx.y
+#define THREAD_ID_Z threadIdx.z
+
+#define BLOCK_NUM_X blockDim.x
+#define BLOCK_NUM_Y blockDim.y
+#define BLOCK_NUM_Z blockDim.z
+
+#define BLOCK_ID_X blockIdx.x
+#define BLOCK_ID_Y blockIdx.y
+#define BLOCK_ID_Z blockIdx.z
+
+#define GRID_NUM_X gridDim.x
+#define GRID_NUM_Y gridDim.y
+#define GRID_NUM_Z gridDim.z
+#endif
 
 namespace paddle {
 namespace operators {
......
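These macros let kernel code ask for its thread, block, and grid coordinates without naming a backend: on CUDA they expand to threadIdx/blockDim/blockIdx/gridDim, on XPU2 to core_id()/core_num()/cluster_id()/cluster_num(), with the unused y/z axes pinned to 0. A sketch of a backend-neutral grid-stride loop written against them; the kernel itself is illustrative, not from the commit:

    // Illustrative element-wise kernel using only the portability macros; the
    // same body resolves to threadIdx/blockDim/... on CUDA and to
    // core_id()/core_num()/... on XPU2.
    template <typename T>
    __global__ void ScaleKernel(T* data, T factor, int n) {
      int stride = GRID_NUM_X * BLOCK_NUM_X;                // total workers in x
      for (int i = BLOCK_ID_X * BLOCK_NUM_X + THREAD_ID_X;  // this worker's start
           i < n; i += stride) {
        data[i] *= factor;
      }
    }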
@@ -130,7 +130,7 @@ __global__ void AddMarginToPositiveLogitsKernel(
 
 template <typename Tx, typename Ty = Tx>
 struct ExpAndSum {
-  using Transformer = kpds::ExpLogitTransformer<Tx>;
+  using Transformer = kps::ExpFunctor<Tx>;
 
   inline Ty initial() { return static_cast<Ty>(0.0f); }
 
@@ -159,7 +159,7 @@ __global__ void LogitsMinusLogSumKernel(T* logits, const T* logits_sum_per_row,
                                         const int64_t N, const int64_t D) {
   CUDA_KERNEL_LOOP(i, N * D) {
     auto row = i / D;
-    logits[i] -= kpds::LogFunctor(logits_sum_per_row[row]);
+    logits[i] -= kps::details::Log(logits_sum_per_row[row]);
   }
 }
 
@@ -174,9 +174,9 @@ __global__ void HardLabelSoftmaxWithCrossEntropyKernel(
     if ((col + start_index) == labels[row]) {
       auto softmax = log_softmax[i];
       loss[row] = -softmax;
-      log_softmax[i] = kpds::ExpFunctor(softmax);
+      log_softmax[i] = kps::details::Exp(softmax);
     } else {
-      log_softmax[i] = kpds::ExpFunctor(log_softmax[i]);
+      log_softmax[i] = kps::details::Exp(log_softmax[i]);
    }
  }
}
......
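Beyond the renames (kpds::ExpLogitTransformer to kps::ExpFunctor, kpds::LogFunctor/ExpFunctor to kps::details::Log/Exp), the math these kernels implement is the standard log-sum-exp form of cross entropy: subtract log(sum(exp(logit))) from each logit to get the log-softmax, then exponentiate back to softmax, taking loss = -log_softmax at the label column. A scalar walkthrough for one row, with made-up values:

    #include <cmath>
    #include <cstdio>

    int main() {
      double logits[3] = {2.0, 1.0, 0.1};  // one row of the logits matrix
      int label = 0;

      double sum = 0.0;
      for (double v : logits) sum += std::exp(v);  // the ExpAndSum reduction
      double log_sum = std::log(sum);

      double loss = 0.0;
      for (int i = 0; i < 3; ++i) {
        double log_softmax = logits[i] - log_sum;  // LogitsMinusLogSumKernel
        if (i == label) loss = -log_softmax;       // HardLabelSoftmax...Kernel
        double softmax = std::exp(log_softmax);    // stored for the backward pass
        std::printf("softmax[%d] = %.4f\n", i, softmax);
      }
      std::printf("loss = %.4f\n", loss);          // about 0.4170
    }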
@@ -24,11 +24,11 @@ limitations under the License. */
 namespace paddle {
 namespace operators {
 
-namespace kpds = paddle::operators::kernel_primitives::details;
+namespace kps = paddle::operators::kernel_primitives;
 
 template <typename Tx, typename Ty = Tx>
 struct CustomMin {
-  using Transformer = kpds::IdentityFunctor<Tx>;
+  using Transformer = kps::IdentityFunctor<Tx>;
 
   inline Ty initial() {
     return static_cast<Ty>(std::numeric_limits<Ty>::max());
@@ -41,7 +41,7 @@ struct CustomMin {
 
 template <typename Tx, typename Ty = Tx>
 struct CustomMax {
-  using Transformer = kpds::IdentityFunctor<Tx>;
+  using Transformer = kps::IdentityFunctor<Tx>;
 
   inline Ty initial() {
     return static_cast<Ty>(std::numeric_limits<Ty>::lowest());
@@ -55,7 +55,7 @@ struct CustomMax {
 // for cub::Reduce
 template <typename Tx, typename Ty = Tx>
 struct CustomSum {
-  using Transformer = kpds::IdentityFunctor<Tx, Ty>;
+  using Transformer = kps::IdentityFunctor<Tx, Ty>;
 
   inline Ty initial() { return static_cast<Ty>(0.0f); }
 
@@ -66,7 +66,7 @@ struct CustomSum {
 
 template <typename Tx, typename Ty = Tx>
 struct CustomMean {
-  using Transformer = kpds::DivideFunctor<Tx>;
+  using Transformer = kps::DivideFunctor<Tx>;
 
   inline Ty initial() { return static_cast<Ty>(0.0f); }
 
@@ -77,7 +77,7 @@ struct CustomMean {
 
 template <typename Tx, typename Ty = Tx>
 struct CustomMul {
-  using Transformer = kpds::IdentityFunctor<Tx>;
+  using Transformer = kps::IdentityFunctor<Tx>;
 
   inline Ty initial() { return static_cast<Ty>(1.0f); }
 
@@ -88,7 +88,7 @@ struct CustomMul {
 
 template <typename Tx, typename Ty = Tx>
 struct CustomLogicalOr {
-  using Transformer = kpds::IdentityFunctor<Tx>;
+  using Transformer = kps::IdentityFunctor<Tx>;
 
   inline Ty initial() { return static_cast<Ty>(false); }
 
@@ -99,7 +99,7 @@ struct CustomLogicalOr {
 
 template <typename Tx, typename Ty = Tx>
 struct CustomLogicalAnd {
-  using Transformer = kpds::IdentityFunctor<Tx>;
+  using Transformer = kps::IdentityFunctor<Tx>;
 
   inline Ty initial() { return static_cast<Ty>(true); }
......
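Each Custom* struct is a reduce descriptor with three parts: a Transformer that preprocesses every input element, an initial() value that seeds the accumulator, and a binary operator() (elided from this diff) that combines partial results. A hedged, standalone sketch of how a reduce driver consumes such a descriptor, with a plain-float CustomMean that mirrors the real one:

    #include <cstdio>
    #include <vector>

    // Standalone stand-ins mirroring the descriptors above for plain float.
    struct DivideFunctor {
      explicit DivideFunctor(int n) : n_inv(1.0f / n) {}
      float operator()(const float& x) const { return x * n_inv; }
      float n_inv;
    };

    struct CustomMean {
      using Transformer = DivideFunctor;
      float initial() { return 0.0f; }
      // The combine step (elided in the diff); for mean it is addition.
      float operator()(const float& a, const float& b) const { return a + b; }
    };

    // Generic driver: construct Transformer with the element count, seed with
    // initial(), transform each element on the way into the combine.
    template <typename ReduceOp>
    float Reduce(const std::vector<float>& src, ReduceOp op) {
      typename ReduceOp::Transformer transform(static_cast<int>(src.size()));
      float acc = op.initial();
      for (float v : src) acc = op(acc, transform(v));
      return acc;
    }

    int main() {
      std::printf("%g\n", Reduce({2, 4, 6, 8}, CustomMean{}));  // prints 5
    }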