未验证 提交 12df57fb 编写于 作者: N niuliling123 提交者: GitHub

add ElementwiseTernary, Reduce, ReadDataStride (#35075)

* add ElementwiseTernary, Reduce, ReadDataStride
上级 d9afa839
......@@ -21,7 +21,8 @@
#include <hip/hip_fp16.h>
#endif
#include <algorithm>
// #include <algorithm>
#include "paddle/fluid/platform/cuda_device_function.h"
#include "paddle/fluid/platform/float16.h"
namespace paddle {
......@@ -29,6 +30,16 @@ namespace operators {
namespace kernel_primitives {
namespace details {
#ifdef __HIPCC__
constexpr int kMaxThread = 256;
constexpr int kWarpSize = 64;
#else
constexpr int kMaxThread = 128;
constexpr int kWarpSize = 32;
#endif
enum ReduceMode { kGlobalMode, kLocalMode };
template <typename T>
class MPTypeTrait {
public:
......@@ -41,37 +52,98 @@ class MPTypeTrait<platform::float16> {
using Type = float;
};
} // namespace details
/**
* @brief will be used in BlockYReduce, get the index of reduce_num in shared
* memory
*/
__device__ __forceinline__ int SharedMemoryIndex(int index) {
return (threadIdx.y + index) * blockDim.x + threadIdx.x;
}
/*************************** Compute Functor****************************/
template <typename T, typename Enable = void>
struct DivFunctor {
inline HOSTDEVICE T operator()(const T* args) const {
return args[0] / args[1];
template <typename T, typename ReduceOp>
__device__ __forceinline__ T WarpReduce(T val, ReduceOp reducer) {
unsigned mask = 0u;
CREATE_SHFL_MASK(mask, true);
for (int stride = details::kWarpSize / 2; stride > 0; stride >>= 1) {
T temp = paddle::platform::CudaShuffleDownSync(mask, val, stride);
val = reducer(val, temp);
}
};
return val;
}
template <typename T>
struct DivFunctor<T, typename std::enable_if_t<std::is_integral<T>::value>> {
inline HOSTDEVICE T operator()(const T* args) const {
PADDLE_ENFORCE(args[1] != 0,
platform::errors::InvalidArgument(
"Invalid Argument Error: Integer division by zero "
"encountered in divide. Please check the input value."));
return args[0] / args[1];
/* e.g.
* |---------block---------|
* |warp0|warp1|warp2|warp3|
* |0~31|32~63|64~95|96~127| ---->blockDim.x = 128
* \|/ \|/ \|/ \|/ ---->1. First WarpReduce in each warp
* res0 res1 res2 res3 ---->2. Store result of each warp to shared memory
* \ \ / / ---->3. Load the result above from shared memory
* res to warp0 and process the second WarpReduce
*/
/**
* @brief BlockXReduce reduce along blockDim.x
*/
template <typename T, typename ReduceOp>
__device__ __forceinline__ T BlockXReduce(T val, ReduceOp reducer) {
__syncthreads();
using details::kWarpSize;
__shared__ T shared[2 * kWarpSize];
int block_dim_x = blockDim.x;
if (blockDim.x > kWarpSize) {
block_dim_x = blockDim.x / kWarpSize;
int lane = threadIdx.x % kWarpSize;
int tid = threadIdx.y * blockDim.x + threadIdx.x;
int wid = tid / kWarpSize;
int bid = threadIdx.y;
val = WarpReduce(val, reducer);
if (lane == 0) {
shared[wid] = val;
}
__syncthreads();
val = shared[bid * block_dim_x + lane];
}
};
unsigned mask = 0u;
CREATE_SHFL_MASK(mask, true);
for (int stride = 1; stride < block_dim_x; stride <<= 1) {
T temp = paddle::platform::CudaShuffleDownSync(mask, val, stride);
val = reducer(val, temp);
}
return val;
}
/**
* @brief BlockYReduce reduce along blockDim.y
*/
template <typename T, typename ReduceOp>
__device__ __forceinline__ T BlockYReduce(T val, ReduceOp reducer) {
__shared__ T shared_memory[details::kMaxThread];
shared_memory[SharedMemoryIndex(0)] = val;
for (int stride = blockDim.y / 2; stride > 0; stride >>= 1) {
__syncthreads();
if (threadIdx.y < stride && threadIdx.y + stride < blockDim.y) {
T temp = shared_memory[SharedMemoryIndex(stride)];
val = reducer(val, temp);
}
shared_memory[SharedMemoryIndex(0)] = val;
}
return val;
}
} // namespace details
/*************************** Compute Function****************************/
/**
* @brief compute functor for elementwise_two, in1 and in2 has the same shape
* @brief binary function, in1 and in2 have same shape
* @param:
* T : the type of in1 and in2
* NX: the row of in1 and in2
* NY: the col of in1 and in2
* BlockSize: the strid of col
* OpFunc: compute functor eg: ADD, SUB, XOR, OR, MUL
* T: data type of in1, in2
* OutT: data type of out
* NX: the cols of in1, in2
* NY: the rows of in1, in2
* BlockSize: the config of this device
* OpFunc: compute functor eg: in1 + in2, in1 - in2
*/
template <typename T, typename OutT, int NX, int NY, int BlockSize,
class OpFunc>
......@@ -88,32 +160,40 @@ __device__ __forceinline__ void ElementwiseBinary(OutT* out, const T* in1,
}
/**
* @brief fma eg: a * b + c, in1 in2, in3 and out has the same shape
* @brief ternary function, in1, in2 and in3 have same shape
* @param:
* T : the type of in1 and in2, in3
* NX: the row of in1, in2 and in3
* NY: the col of in1, in2 and in3
* BlockSize: the strid of col
* T: data type of in1, in2, in3
* OutT: data type of out
* NX: the cols of in1, in2
* NY: the rows of in1, in2
* BlockSize: the config of this device
* OpFunc: compute functor eg: out = in1 * in2 + in3
*/
template <typename T, typename OutT, int NX, int NY, int BlockSize,
class OpFunc>
__device__ __forceinline__ void ElementwiseFma(OutT* out, const T* in1,
const T* in2, const T* in3,
OpFunc compute) {
__device__ __forceinline__ void ElementwiseTernary(OutT* out, const T* in1,
const T* in2, const T* in3,
OpFunc compute) {
T args[3];
#pragma unroll
for (int idx = 0; idx < NX * NY; ++idx) {
out[idx] = static_cast<OutT>(compute(in1[idx], in2[idx], in3[idx]));
args[0] = in1[idx];
args[1] = in2[idx];
args[2] = in3[idx];
out[idx] = static_cast<OutT>(compute(args));
}
}
/**
* @brief compute functor for elementwise_two, in1 is [1, NY], in2 is [NX, NY]
* @brief cycle binary function, in1's shape size is [1, NX], in2's shape size
* is [NY, NX], out's shape size is [NY, NX]
* @param:
* T : the type of in1 and in2
* NX: the row of in1 and in2
* NY: the col of in2
* BlockSize: the strid of col
* OpFunc: compute functor eg: ADD, SUB, XOR, OR, MUL
* T: data type of in1, in2
* OutT: data type of out
* NX: the cols of in1, in2
* NY: the rows of in1, in2
* BlockSize: the config of this device
* OpFunc: compute functor eg: in1 + in2, in1 - in2
*/
template <typename T, typename OutT, int NX, int NY, int BlockSize,
class OpFunc>
......@@ -130,13 +210,14 @@ __device__ __forceinline__ void CycleBinary(OutT* out, const T* in1,
}
/**
* @brief compute functor for unary, in1 is [NX, NY]
* @brief unary function
* @param:
* T : the type of in
* NX: the row of in
* NY: the col of in
* BlockSize: the strid of col
* OpFunc: compute functor eg: relu, sigmoid, exp
* T: data type of in
* OutT: data type of out
* NX: the cols of in
* NY: the rows of in
* BlockSize: the config of this device
* OpFunc: compute functor eg: relu, exp
*/
template <typename T, typename OutT, int NX, int NY, int BlockSize,
class OpFunc>
......@@ -148,6 +229,59 @@ __device__ __forceinline__ void ElementwiseUnary(OutT* out, const T* in,
}
}
/**
* @brief reduce function, in's shape size is [NX, NY].
* If ReduceMode == kLocalMode then reduce NX, the shape of out is [NY, 1],
* if ReduceMode == kGlobalMode then reduce between different threads, the
* shape of out is [NY, NX]. If reduce_last_dim is false and reduce_num was
* split, BlockYReduce will be called. If reduce_last_dim is true and
* reduce_num was split, BlockXReduce will be called
* @typename:
* T: data type of in
* NX: the cols of in
* NY: the rows of in
* BlockSize: the config of this device
* OpFunc: reduce functor, eg: CustomSum, CustomMean in reduce_functor_op.h
* @param:
* reducer: reduce functor, eg: CustomSum<T>()
* reduce_last_dim: if in's last dim need to be reduce then reduce_last_dim =
* true
*/
template <typename T, int NX, int NY, int BlockSize, class OpFunc,
details::ReduceMode Mode>
__device__ __forceinline__ void Reduce(T* out, const T* in, OpFunc reducer,
bool reduce_last_dim) {
int block_index = blockDim.y;
if (Mode == details::ReduceMode::kGlobalMode) {
bool block_reduce_y = (!reduce_last_dim) && (block_index > 1);
// when reduce is not required for the last dim, and reduce num has been
// split into multiple threads
if (block_reduce_y) {
#pragma unroll
for (int i = 0; i < NY * NX; i++) { // reduce along blockdim.y
out[i] = details::BlockYReduce<T, OpFunc>(out[i], reducer);
}
}
// when last dimension need to be reduced
if (reduce_last_dim) {
#pragma unroll
for (int i = 0; i < NY * NX; i++) { // reduce along blockDim.x
out[i] = details::BlockXReduce<T, OpFunc>(out[i], reducer);
}
}
} else { // else kLocalMode
#pragma unroll
for (int i = 0; i < NY; ++i) {
#pragma unroll
for (int j = 0; j < NX; ++j) {
out[i] = reducer(out[i], in[i * NX + j]);
}
}
}
}
} // namespace kernel_primitives
} // namespace operators
} // namespace paddle
......@@ -13,11 +13,13 @@
// limitations under the License.
#pragma once
#ifdef PADDLE_WITH_CUDA
#include <cuda.h>
#include <cuda_fp16.h>
#include <math.h>
#include <iostream>
#include <vector>
#endif
#ifdef PADDLE_WITH_HIP
#include <hip/hip_fp16.h>
#endif
namespace paddle {
namespace operators {
......@@ -104,52 +106,197 @@ struct BroadcastConfig {
#undef INT_BITS
} // namespace details
template <typename T, int NX, int NY, int BlockSize>
__device__ __forceinline__ void ReadDataBase(T* dst, const T* __restrict__ src,
int size) {
/**
* @brief load data from src to dst, src can be 1D data or 2D data. Note that
* you can use this function when you are sure that the data will not cross the
* boundary.
* @typename:
* Tx: data type of src
* Ty: data type of dstt
* NX: the cols of src, dst
* NY: the rows of src, dst
* BlockSize: the config of this device
* @param:
* stride_nx: the stride of cols
* stride_ny: the stride of rows
*/
template <typename Tx, typename Ty, int NX, int NY, int BlockSize>
__device__ __forceinline__ void ReadData(Ty* dst, const Tx* __restrict__ src,
int stride_nx, int stride_ny) {
if (NY == 1 && NX == 1) {
dst[0] = static_cast<Ty>(src[threadIdx.x]);
} else if (NX == 1) {
int dx = threadIdx.x;
#pragma unroll
for (int idy = 0; idy < NY; ++idy) {
dst[idy] = static_cast<Ty>(src[dx + idy * stride_ny]);
}
} else if (NY == 1) {
#pragma unroll
for (int idx = 0; idx < NX; ++idx) {
dst[idx] = static_cast<Ty>(src[idx * stride_nx]);
}
} else {
int dx = threadIdx.x * NX;
#pragma unroll
for (int idx = 0; idx < NX; ++idx) {
#pragma unroll
for (int idy = 0; idy < NY; ++idy) {
dst[idy * NX + idx] =
static_cast<Ty>(src[idx * stride_nx + dx + idy * stride_ny]);
}
}
}
}
/**
* @brief load data from src to dst, src can be 1D data or 2D data. When
* boundary judgment is required, you need to set a to true, and a is false by
* default.
* @typename:
* Tx: data type of src
* Ty: data type of dstt
* NX: the cols of src, dst
* NY: the rows of src, dst
* BlockSize: the config of this device
* IsBoundary: whether to make boundary judgment
* @param:
* size_nx: number of columns to be processed by the current block
* size_ny: number of rows to be processed by the current block
* stride_nx: the stride of cols
* stride_ny: the stride of rows
*/
template <typename Tx, typename Ty, int NX, int NY, int BlockSize,
bool IsBoundary = false>
__device__ __forceinline__ void ReadData(Ty* dst, const Tx* __restrict__ src,
int size_nx, int size_ny,
int stride_nx, int stride_ny) {
int dx = threadIdx.x * NX;
int size = size_nx - dx;
// Each branch is added for better performance
if (NX == 1 && NY == 1) { // for NX == 1 and NY == 1
if (IsBoundary) {
if (dx < size_nx) {
dst[0] = static_cast<Ty>(src[dx]);
}
} else {
dst[0] = static_cast<Ty>(src[dx]);
}
} else if (NX == 1) { // for NX == 1 and NY != 1
#pragma unroll
for (int idx = 0; idx < NX; ++idx) {
if ((idx + dx) >= size) {
break;
for (int idy = 0; idy < NY; ++idy) {
if (IsBoundary) {
if (idy >= size_ny) {
break;
}
}
dst[idy] = static_cast<Ty>(src[dx + idy * stride_ny]);
}
} else if (NY == 1) { // for NY == 1 and NX != 1
#pragma unroll
for (int idx = 0; idx < NX; ++idx) {
if (IsBoundary) {
if (idx >= size) {
break;
}
}
dst[idx] = static_cast<Ty>(src[idx * stride_nx + dx]);
}
} else { // for NX != 1 and NY != 1
#pragma unroll
for (int idx = 0; idx < NX; ++idx) {
if (IsBoundary) {
if (idx >= size) {
break;
}
}
#pragma unroll
for (int idy = 0; idy < NY; ++idy) {
if (IsBoundary) {
if (idy >= size_ny) {
break;
}
}
dst[idy * NX + idx] =
static_cast<Ty>(src[idx * stride_nx + dx + idy * stride_ny]);
}
}
dst[idx] = src[idx + dx];
}
}
template <typename T, int NX, int NY, int BlockSize>
template <typename T, int NX>
__device__ __forceinline__ void Init(T* dst, T init_data) {
#pragma unroll
for (int i = 0; i < NX; i++) {
dst[i] = init_data;
}
}
/** @brief: ReadData
* @brief load data from src to dst, src can be 1D data, you should set NY = 1.
* When boundary judgment is required, you need to set a to true, and a is false
* by default.
* @typename:
* T : the data type of src
* NX: the cols of src, dst
* NY: in this function NY only can be 1
* BlockSize: the config of this device
* IsBoundary: whether to make boundary judgment
* @param:
* num: number of columns to be processed by the current block
*/
template <typename T, int NX, int NY, int BlockSize, bool IsBoundary = false>
__device__ __forceinline__ void ReadData(T* dst, const T* __restrict__ src,
int size) {
const int VECTOR_SIZE = (NX % 4 == 0) ? 4 : (NX % 2 == 0) ? 2 : 1;
const int VECTORS_PER_THREAD = NX / VECTOR_SIZE;
int num) {
if (IsBoundary) { // blockDim.x * NX > num
int dx = threadIdx.x * NX;
#pragma unroll
for (int idx = 0; idx < NX; ++idx) {
if (idx + dx < num) {
dst[idx] = src[idx + dx];
}
}
} else { // blockDim,x * NX < num
const int kVectorSize = (NX % 4 == 0) ? 4 : (NX % 2 == 0) ? 2 : 1;
const int kVectorsPerThread = NX / kVectorSize;
int tid = threadIdx.x * kVectorsPerThread;
// Vector per thread
if (blockDim.x * NX > size) {
ReadDataBase<T, NX, NY, BlockSize>(dst, src, size);
} else {
// Vector type
using VecType = details::VectorType<T, VECTOR_SIZE>;
VecType vec_temp[VECTORS_PER_THREAD];
using VecType = details::VectorType<T, kVectorSize>;
const VecType* vec_input = reinterpret_cast<const VecType*>(src);
ReadDataBase<VecType, VECTORS_PER_THREAD, NY, BlockSize>(
vec_temp, vec_input, VECTORS_PER_THREAD * blockDim.x);
VecType vec_temp[kVectorsPerThread];
#pragma unroll
for (int idx = 0; idx < NX; ++idx) {
dst[idx] = *(reinterpret_cast<T*>(vec_temp) + idx);
for (int i = 0; i < kVectorsPerThread; ++i) {
vec_temp[i] = vec_input[i + tid];
#pragma unroll
for (int idx = 0; idx < NX; ++idx) {
dst[idx] = *(reinterpret_cast<T*>(vec_temp) + idx);
}
}
}
}
/** @brief: ReadDataBc
* read data from src ptr when the shape of src and dst are different
/**
* @brief: read data for broadcast
* @typename:
* T : the data type of src
* NX: the cols of src, dst
* NY: in this function NY only can be 1
* BlockSize: the config of this device
* ShapeSize: the shape size of out. eg in[1, 35], out[32, 35] then shape size
* is 2
* IsBoundary: whether to make boundary judgment
* @param:
* src: the source pointer
* dst: the dst pointer
* stride_nx: the stride of src
* stride_ny: the stride of src
* the shape of dst is [NY, NX]
* fix: data offset of this block, blockDim.x * blockIdx.x * NX;
* config: get the global index in src, attention config was declared in host;
* num: the num of out
* stride_nx: the stride of cols
* stride_ny: the stride of rows
*/
template <typename T, int NX, int NY, int BlockSize, int ShapeSize>
template <typename T, int NX, int NY, int BlockSize, int ShapeSize,
bool IsBoundary = false>
__device__ __forceinline__ void ReadDataBc(
T* dst, const T* __restrict__ src, uint32_t fix,
details::BroadcastConfig<ShapeSize> config, int num, int stride_nx,
......@@ -162,53 +309,130 @@ __device__ __forceinline__ void ReadDataBc(
#pragma unroll
for (uint32_t nx = 0; nx < NX; ++nx) {
uint32_t idx = base_offset + ny * stride_ny + nx * stride_nx;
if (idx < num) {
offset = 0;
#pragma unroll
for (int i = 0; i < ShapeSize; ++i) {
auto fast_divmoder = config.divmoders[i].Divmod(idx);
idx = fast_divmoder.val[0];
offset += fast_divmoder.val[1] * config.strides[i];
if (IsBoundary) {
if (idx >= num) {
break;
}
dst[nx + ny * NX] = src[offset];
}
offset = 0;
#pragma unroll
for (int i = 0; i < ShapeSize; ++i) {
auto fast_divmoder = config.divmoders[i].Divmod(idx);
idx = fast_divmoder.val[0];
offset += fast_divmoder.val[1] * config.strides[i];
}
dst[nx + ny * NX] = src[offset];
}
}
}
template <typename T, int NX, int NY, int BlockSize>
__device__ __forceinline__ void WriteDataBase(T* dst, const T* __restrict__ src,
int size) {
int dx = threadIdx.x * NX;
/**
* @brief: read data for broadcast
* @typename:
* T : the data type of src
* NX: the cols of src, dst
* NY: in this function NY only can be 1
* BlockSize: the config of this device
* ShapeSize: the shape size of out. eg in[1, 35], out[32, 35] then shape size
* is 2
* IndexCal: get the global index in src, attention config was declared in host;
* IsBoundary: whether to make boundary judgment
* @param:
* fix: data offset of this block, blockDim.x * blockIdx.x * NX;
* index_cal: get the global index in src, attention config was declared in
* host;
* size_nx: number of columns to be processed by the current block
* size_ny: number of rows to be processed by the current block
* stride_nx: the stride of cols
* stride_ny: the stride of rows
* reduce_last_dim: according to the block split set threadIdx
*/
template <typename T, int NX, int NY, int BlockSize, int ShapeSize,
typename IndexCal, bool IsBoundary = false>
__device__ __forceinline__ void ReadDataReduce(
T* dst, const T* __restrict__ src, int fix, const IndexCal& index_cal,
int size_nx, int size_ny, int stride_nx, int stride_ny,
bool reduce_last_dim) {
int base_offset = fix;
if (reduce_last_dim) {
base_offset += threadIdx.x;
} else {
base_offset += threadIdx.y;
}
if (NX == 1) {
#pragma unroll
for (int idx = 0; idx < NX; ++idx) {
if ((idx + dx) >= size) {
break;
for (int ny = 0; ny < NY; ++ny) {
if (IsBoundary) {
if (base_offset >= size_ny) {
break;
}
}
uint32_t offset = index_cal(base_offset);
dst[ny] = src[offset];
base_offset += stride_ny;
}
} else {
#pragma unroll
for (int nx = 0; nx < NX; ++nx) {
if (IsBoundary) {
if (nx * stride_nx >= size_nx) {
break;
}
}
#pragma unroll
for (int ny = 0; ny < NY; ++ny) {
if (IsBoundary) {
if (nx * stride_nx >= size_nx) {
break;
}
}
uint32_t offset = index_cal(base_offset);
dst[nx + ny * NX] = src[offset];
base_offset += stride_ny;
}
}
dst[idx + dx] = src[idx];
}
}
template <typename T, int NX, int NY, int BlockSize>
/** @brief: WriteData
* @brief store data from src to dst, src can be 1D data, you should set NY = 1.
* When boundary judgment is required, you need to set a to true, and a is false
* by default.
* @typename:
* T : the data type of src
* NX: the cols of src, dst
* NY: in this function NY only can be 1
* BlockSize: the config of this device
* IsBoundary: whether to make boundary judgment
* @param:
* num: number of columns to be processed by the current block
*/
template <typename T, int NX, int NY, int BlockSize, bool IsBoundary = false>
__device__ __forceinline__ void WriteData(T* dst, T* __restrict__ src,
int size) {
const int VECTOR_SIZE = (NX % 4 == 0) ? 4 : (NX % 2 == 0) ? 2 : 1;
const int VECTORS_PER_THREAD = NX / VECTOR_SIZE;
// Vector per thread
if (blockDim.x * NX > size) {
WriteDataBase<T, NX, NY, BlockSize>(dst, src, size);
int num) {
if (IsBoundary) {
int dx = threadIdx.x * NX;
#pragma unroll
for (int idx = 0; idx < NX; ++idx) {
if ((idx + dx) < num) {
dst[idx + dx] = src[idx];
}
}
} else {
// Vector type
using VecType = details::VectorType<T, VECTOR_SIZE>;
VecType vec_temp[VECTORS_PER_THREAD];
const int kVectorSize = (NX % 4 == 0) ? 4 : (NX % 2 == 0) ? 2 : 1;
const int kVectorsPerThread = NX / kVectorSize;
int dx = threadIdx.x * kVectorsPerThread;
using VecType = details::VectorType<T, kVectorSize>;
VecType* vec_dst = reinterpret_cast<VecType*>(dst);
VecType vec_temp[kVectorsPerThread];
#pragma unroll
for (int idx = 0; idx < VECTORS_PER_THREAD; ++idx) {
for (int idx = 0; idx < kVectorsPerThread; ++idx) {
vec_temp[idx] = *(reinterpret_cast<VecType*>(src) + idx);
vec_dst[dx + idx] = vec_temp[idx];
}
VecType* vec_dst = reinterpret_cast<VecType*>(dst);
WriteDataBase<VecType, VECTORS_PER_THREAD, NY, BlockSize>(
vec_dst, vec_temp, VECTORS_PER_THREAD * blockDim.x);
}
}
......
......@@ -16,6 +16,65 @@
namespace paddle {
namespace operators {
namespace kernel_primitives {}
namespace kernel_primitives {
namespace details {
static __device__ __forceinline__ platform::float16 ExpFunctor(
platform::float16 x) {
return ::Eigen::numext::exp(x);
}
static __device__ __forceinline__ float ExpFunctor(float x) { return expf(x); }
static __device__ __forceinline__ double ExpFunctor(double x) { return exp(x); }
static __device__ __forceinline__ platform::float16 LogFunctor(
platform::float16 x) {
return ::Eigen::numext::log(x);
}
static __device__ __forceinline__ float LogFunctor(float x) { return logf(x); }
static __device__ __forceinline__ double LogFunctor(double x) { return log(x); }
} // namespace details
/*************************** Compute Functor****************************/
// for margin_cross_entropy
template <typename Tx, typename Ty = Tx>
struct ExpLogitTransformer {
HOSTDEVICE explicit inline ExpLogitTransformer(int n) {}
HOSTDEVICE inline Ty operator()(const Tx* x) const {
return static_cast<Ty>(details::ExpFunctor(x[0]));
}
HOSTDEVICE inline Ty operator()(const Tx& x) const {
return static_cast<Ty>(details::ExpFunctor(x));
}
};
// Post processing function for sum, max, min, prod, any
template <typename Tx, typename Ty = Tx>
struct IdentityFunctor {
HOSTDEVICE explicit inline IdentityFunctor(int n) {}
HOSTDEVICE inline Ty operator()(const Tx* x) const {
return static_cast<Ty>(x[0]);
}
HOSTDEVICE inline Ty operator()(const Tx& x) const {
return static_cast<Ty>(x);
}
};
// Post processing function for mean
template <typename T>
struct DivideFunctor {
HOSTDEVICE explicit inline DivideFunctor(int n) : n_inv((T)(1.0 / n)) {}
HOSTDEVICE inline T operator()(const T* x) const { return x[0] * n_inv; }
HOSTDEVICE inline T operator()(const T& x) const { return x * n_inv; }
private:
T n_inv;
};
} // namespace kernel_primitives
} // namespace operators
} // namespace paddle
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册