未验证 提交 480b284c 编写于 作者: N niuliling123 提交者: GitHub

modified reduce_max, reduce_min, reduce_prod to higher_performance implementation. (#32974)

上级 20eafd79
......@@ -13,46 +13,98 @@ See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <string>
#include <vector>
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/operators/amp/fp16_type_traits.h"
#include "paddle/fluid/platform/device_context.h"
#include <cmath>
#include <limits>
#include "paddle/fluid/operators/reduce_ops/reduce_op.cu.h"
#include "paddle/fluid/platform/hostdevice.h"
#include "paddle/fluid/platform/macros.h"
#ifdef __HIPCC__
#include <hip/hip_runtime.h>
#endif
namespace paddle {
namespace operators {
template <typename T>
template <typename Tx, typename Ty = Tx>
struct CustomMin {
__device__ __forceinline__ T operator()(const T &a, const T &b) const {
using Transformer = detail::IdentityFunctor<Tx>;
inline Ty initial() {
return static_cast<Ty>(std::numeric_limits<Ty>::max());
}
__device__ __forceinline__ Ty operator()(const Ty &a, const Ty &b) const {
return (b < a) ? b : a;
}
};
template <typename T>
template <typename Tx, typename Ty = Tx>
struct CustomMax {
__device__ __forceinline__ T operator()(const T &a, const T &b) const {
using Transformer = detail::IdentityFunctor<Tx>;
inline Ty initial() {
return static_cast<Ty>(std::numeric_limits<Ty>::lowest());
}
__device__ __forceinline__ Ty operator()(const Ty &a, const Ty &b) const {
return (b > a) ? b : a;
}
};
template <typename T>
// for cub::Reduce
template <typename Tx, typename Ty = Tx>
struct CustomSum {
__device__ __forceinline__ T operator()(const T &a, const T &b) const {
using Transformer = detail::IdentityFunctor<Tx, Ty>;
inline Ty initial() { return static_cast<Ty>(0.0f); }
__device__ __forceinline__ Ty operator()(const Ty &a, const Ty &b) const {
return b + a;
}
};
template <typename T>
template <typename Tx, typename Ty = Tx>
struct CustomMean {
using Transformer = detail::DivideFunctor<Tx>;
inline Ty initial() { return static_cast<Ty>(0.0f); }
__device__ __forceinline__ Ty operator()(const Ty &a, const Ty &b) const {
return b + a;
}
};
template <typename Tx, typename Ty = Tx>
struct CustomMul {
__device__ __forceinline__ T operator()(const T &a, const T &b) const {
using Transformer = detail::IdentityFunctor<Tx>;
inline Ty initial() { return static_cast<Ty>(1.0f); }
__device__ __forceinline__ Ty operator()(const Ty &a, const Ty &b) const {
return b * a;
}
};
template <typename Tx, typename Ty = Tx>
struct CustomLogicalOr {
using Transformer = detail::IdentityFunctor<Tx>;
inline Ty initial() { return static_cast<Ty>(false); }
__device__ __forceinline__ Ty operator()(const Ty &a, const Ty &b) const {
return b || a;
}
};
template <typename Tx, typename Ty = Tx>
struct CustomLogicalAnd {
using Transformer = detail::IdentityFunctor<Tx>;
inline Ty initial() { return static_cast<Ty>(true); }
__device__ __forceinline__ Ty operator()(const Ty &a, const Ty &b) const {
return b && a;
}
};
} // namespace operators
} // namespace paddle
......@@ -11,15 +11,13 @@
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/reduce_ops/reduce_functor_op.h"
#include "paddle/fluid/operators/reduce_ops/reduce_op.cu.h"
#include "paddle/fluid/operators/reduce_ops/reduce_op.h"
#include "paddle/fluid/operators/reduce_ops/reduce_min_max_op.h"
REGISTER_OP_CUDA_KERNEL(reduce_max,
ops::ReduceKernel<paddle::platform::CUDADeviceContext,
float, ops::MaxFunctor>,
ops::ReduceKernel<paddle::platform::CUDADeviceContext,
double, ops::MaxFunctor>,
ops::ReduceKernel<paddle::platform::CUDADeviceContext,
int, ops::MaxFunctor>,
ops::ReduceKernel<paddle::platform::CUDADeviceContext,
int64_t, ops::MaxFunctor>);
// reduce_max
REGISTER_OP_CUDA_KERNEL(
reduce_max, ops::ReduceCudaKernel<float, paddle::operators::CustomMax>,
ops::ReduceCudaKernel<double, paddle::operators::CustomMax>,
ops::ReduceCudaKernel<int, paddle::operators::CustomMax>,
ops::ReduceCudaKernel<int64_t, paddle::operators::CustomMax>);
......@@ -11,15 +11,13 @@
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/reduce_ops/reduce_functor_op.h"
#include "paddle/fluid/operators/reduce_ops/reduce_op.cu.h"
#include "paddle/fluid/operators/reduce_ops/reduce_op.h"
#include "paddle/fluid/operators/reduce_ops/reduce_min_max_op.h"
REGISTER_OP_CUDA_KERNEL(reduce_min,
ops::ReduceKernel<paddle::platform::CUDADeviceContext,
float, ops::MinFunctor>,
ops::ReduceKernel<paddle::platform::CUDADeviceContext,
double, ops::MinFunctor>,
ops::ReduceKernel<paddle::platform::CUDADeviceContext,
int, ops::MinFunctor>,
ops::ReduceKernel<paddle::platform::CUDADeviceContext,
int64_t, ops::MinFunctor>);
// reduce_min
REGISTER_OP_CUDA_KERNEL(
reduce_min, ops::ReduceCudaKernel<float, paddle::operators::CustomMin>,
ops::ReduceCudaKernel<double, paddle::operators::CustomMin>,
ops::ReduceCudaKernel<int, paddle::operators::CustomMin>,
ops::ReduceCudaKernel<int64_t, paddle::operators::CustomMin>);
......@@ -30,32 +30,59 @@ namespace cub = hipcub;
#endif
#include "paddle/fluid/framework/array.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/framework/tensor_util.h"
// Reduce split or not, Whether to use ReduceHigherDim
#define REDUCE_SPLIT_BOUNDARY 512
namespace paddle {
namespace operators {
namespace detail {
// Post processing function for sum, max, min, prod, any
template <typename T>
template <typename Tx, typename Ty = Tx>
struct IdentityFunctor {
DEVICE explicit inline IdentityFunctor() {}
HOSTDEVICE explicit inline IdentityFunctor(int n) {}
DEVICE inline T operator()(const T& x) const { return x; }
HOSTDEVICE inline Ty operator()(const Tx& x) const {
return static_cast<Ty>(x);
}
};
// Post processing function for mean
template <typename T>
struct DivideFunctor {
DEVICE explicit inline DivideFunctor(int n) : n_inv((T)(1.0 / n)) {}
HOSTDEVICE explicit inline DivideFunctor(int n) : n_inv((T)(1.0 / n)) {}
DEVICE inline T operator()(const T& x) const { return x * n_inv; }
HOSTDEVICE inline T operator()(const T& x) const { return x * n_inv; }
private:
T n_inv;
};
static inline std::vector<int> GetReduceDim(const std::vector<int>& dims,
int dim_size, bool reduce_all) {
std::vector<int> reduce_dims;
if (reduce_all) {
reduce_dims.resize(dim_size);
for (int i = 0; i < reduce_dims.size(); ++i) {
reduce_dims[i] = i;
}
} else {
for (auto e : dims) {
PADDLE_ENFORCE_LT(e, dim_size,
paddle::platform::errors::InvalidArgument(
"ReduceOp: invalid axis, when x_dims is %d, "
"axis[i] should less than x_dims, but got %d.",
dim_size, e));
reduce_dims.push_back(e >= 0 ? e : e + dim_size);
}
}
return reduce_dims;
}
static inline int GetLastPow2(int n) {
n |= (n >> 1);
n |= (n >> 2);
......@@ -65,8 +92,9 @@ static inline int GetLastPow2(int n) {
return std::max(1, n - (n >> 1));
}
static inline std::vector<int> GetStrides(const std::vector<int>& dims,
const std::vector<int>& idx) {
// get strides of x_dim, reduce_dim and left_dim for reduceLastDim and reduceAny
static inline std::vector<int> GetDimStrides(const std::vector<int>& dims,
const std::vector<int>& idx) {
int n = static_cast<int>(idx.size());
if (n == 0) return std::vector<int>();
std::vector<int> strides(n);
......@@ -78,18 +106,18 @@ static inline std::vector<int> GetStrides(const std::vector<int>& dims,
}
#ifdef __HIPCC__
constexpr int kMaxBlockDim = 256;
constexpr int kMaxThread = 256;
#else
constexpr int kMaxBlockDim = 512;
constexpr int kMaxThread = 128;
#endif
static inline int GetDesiredBlockDim(int block_dim) {
return block_dim >= kMaxBlockDim
? kMaxBlockDim
: (1 << static_cast<int>(std::log2(block_dim)));
// get blockDim for reduceLastDim and reduceAny
static inline int GetBlockDim(int block_dim) {
return block_dim >= kMaxThread ? kMaxThread : GetLastPow2(block_dim);
}
static inline void CheckReduceRankIsValid(int reduce_rank, int rank) {
// check reduce rand is valid
static inline void CheckReduceRank(int reduce_rank, int rank) {
if (rank % 2 == 0) {
PADDLE_ENFORCE_EQ(reduce_rank, rank / 2,
platform::errors::InvalidArgument(
......@@ -108,8 +136,9 @@ static inline void CheckReduceRankIsValid(int reduce_rank, int rank) {
}
}
// convert dims from vector to array
template <typename T, size_t ElementCount, typename VectorLikeType>
static inline paddle::framework::Array<T, ElementCount> from(
static inline paddle::framework::Array<T, ElementCount> VectorToArray(
const VectorLikeType& vec) {
PADDLE_ENFORCE_EQ(vec.size(), ElementCount,
platform::errors::InvalidArgument(
......@@ -118,17 +147,21 @@ static inline paddle::framework::Array<T, ElementCount> from(
vec.size(), ElementCount));
size_t n = static_cast<size_t>(vec.size());
paddle::framework::Array<T, ElementCount> ret;
for (size_t i = 0; i < n; ++i) ret[i] = vec[i];
for (size_t i = 0; i < n; ++i) {
ret[i] = vec[i];
}
return ret;
}
} // namespace detail
using Tensor = framework::Tensor;
enum ReduceType {
kReduceAll = 0x00,
kReduceLastDim = 0x01,
kReduceAll = 0x00, // when reduce_rank == x_rank
kReduceLastDim = 0x01, // when reduce_dim[0] == x_dim.size() - 1;
kReduceHigherDim = 0x02, // ReduceFirstDim or reduceSecondDim
kReduceAny = 0x03,
kReduceAny = 0x03, // when reduce_dim.size() > 1
};
// reduce config
......@@ -141,21 +174,24 @@ struct ReduceConfig {
void Run() {
// step1: update the reduce_dim left_dim and x_dim
SetReduceDim();
// step2: get the strides of dim for reduceAny and reduceLastDim
SetStrides();
// step3: get the type of reduce
SetReduceType();
// step4: set the block and grid for launch kernel
SetBlockDim();
}
// when should_reduce_again is true, we need malloc temp space for temp data
void SetOutputData(Ty* y_data, const platform::Place& place,
framework::Tensor& tmp) {
framework::Tensor* tmp) {
if (should_reduce_again) {
output_data = tmp.mutable_data<Ty>(
output_data = tmp->mutable_data<Ty>(
framework::make_ddim(
{static_cast<int64_t>(left_num * grid.y * sizeof(Ty))}),
{static_cast<int64_t>(left_num * grid.z * grid.y * sizeof(Ty))}),
place);
} else {
output_data = y_data;
......@@ -168,50 +204,70 @@ struct ReduceConfig {
// --SetReduceDim--> x_dim = [8,6], reduce_dim = [0], left_dim = [1]
void SetReduceDim() {
std::set<int> reduce_set;
for (auto e : reduce_dims_origin) {
auto pos = e >= 0 ? e : e + x_dim.size();
reduce_set.insert(pos);
}
std::vector<int> reduce_dim_temp(reduce_set.begin(), reduce_set.end());
std::sort(reduce_dim_temp.begin(), reduce_dim_temp.end());
// get reduce_dim
// update reduce_dim and x_dim
std::vector<int> x_new_dim;
reduce_dim.push_back(reduce_dim_temp[0]);
x_new_dim.push_back(x_dim[0]);
int idx_reduce = 1;
int num = 0;
if (reduce_dim_temp.size() > 1) {
int num = 0; // for update axis
reduce_dim.push_back(reduce_dim_temp[0]);
for (int idx = 1; idx < reduce_dim_temp.size(); idx++) {
// update x_dim
if (reduce_dim_temp[idx] - reduce_dim_temp[idx - 1] == 1) {
x_dim[reduce_dim_temp[idx - 1]] *= x_dim[reduce_dim_temp[idx]];
x_dim.erase(x_dim.begin() + reduce_dim_temp[idx]);
num++;
for (int i = 1; i < x_dim.size(); i++) {
if ((idx_reduce < reduce_dim_temp.size()) &&
(i == reduce_dim_temp[idx_reduce])) {
int result =
reduce_dim_temp[idx_reduce] - reduce_dim[reduce_dim.size() - 1];
bool is_equal = ((result - num) == 1);
if (is_equal) {
x_new_dim[x_new_dim.size() - 1] *= x_dim[i];
num++;
} else {
reduce_dim.push_back(reduce_dim_temp[idx_reduce] - num);
x_new_dim.push_back(x_dim[i]);
}
idx_reduce++;
} else {
reduce_dim.push_back(reduce_dim_temp[idx] - num);
x_new_dim.push_back(x_dim[i]);
}
}
} else {
reduce_dim = reduce_dim_temp;
x_new_dim = x_dim;
}
// update new_x_dim and new_reduce_dim
std::vector<int> new_x_dim, new_reduce_dim_temp;
// update x_dim
x_dim = x_new_dim;
std::vector<int>().swap(x_new_dim);
std::vector<int> reduce_dim_new;
int is_reduced = 0;
for (auto e : reduce_dim) {
is_reduced |= 1 << e;
}
std::vector<int>().swap(reduce_dim);
for (int i = 0; i < x_dim.size(); i++) {
if ((i == 0) || (((is_reduced >> i) ^ (is_reduced >> (i - 1))) & 1)) {
new_x_dim.push_back(x_dim[i]);
x_new_dim.push_back(x_dim[i]);
if ((is_reduced >> i) & 1)
new_reduce_dim_temp.push_back(new_x_dim.size() - 1);
reduce_dim_new.push_back(x_new_dim.size() - 1);
} else {
new_x_dim[new_x_dim.size() - 1] *= x_dim[i];
x_new_dim[x_new_dim.size() - 1] *= x_dim[i];
}
}
x_dim = new_x_dim;
reduce_dim = new_reduce_dim_temp;
x_dim = x_new_dim;
reduce_dim = reduce_dim_new;
int x_rank = static_cast<int>(x_dim.size());
std::set<int> left_set;
......@@ -237,9 +293,9 @@ struct ReduceConfig {
idx_dim.push_back(i);
}
x_strides = detail::GetStrides(x_dim, idx_dim);
reduce_strides = detail::GetStrides(x_dim, reduce_dim);
left_strides = detail::GetStrides(x_dim, left_dim);
x_strides = detail::GetDimStrides(x_dim, idx_dim);
reduce_strides = detail::GetDimStrides(x_dim, reduce_dim);
left_strides = detail::GetDimStrides(x_dim, left_dim);
reduce_num = reduce_strides[0] * x_dim[reduce_dim[0]];
left_num = 1;
......@@ -256,13 +312,17 @@ struct ReduceConfig {
void SetReduceType() {
int rank = x_dim.size();
int reduce_rank = reduce_dim.size();
bool is_large_enough = (reduce_num > REDUCE_SPLIT_BOUNDARY / 2) ||
(left_num > REDUCE_SPLIT_BOUNDARY);
if (rank == reduce_rank) {
reduce_type = static_cast<int>(ReduceType::kReduceAll);
} else if (rank == 2 && reduce_rank == 1 && reduce_dim[0] == 1) {
reduce_type = static_cast<int>(ReduceType::kReduceLastDim);
} else if (reduce_rank == 1) {
} else if (reduce_rank == 1 &&
((rank == 2 && is_large_enough) || rank != 2)) {
// ReduceFirstDim and reduceSecondDim
reduce_type = static_cast<int>(ReduceType::kReduceHigherDim);
......@@ -277,7 +337,7 @@ struct ReduceConfig {
// for others: block(block_num, 1) , grid(left_num, 1)
void SetBlockDim() {
// init
int block_num = detail::GetDesiredBlockDim(reduce_num);
int block_num = detail::GetBlockDim(reduce_num);
should_reduce_again = false;
dim3 block_dim(block_num, 1);
......@@ -302,7 +362,7 @@ struct ReduceConfig {
// init
int num_block = (max_threads / left_num);
if (num_block > 1 && reduce_num >= 512) {
if (num_block > 1 && reduce_num >= REDUCE_SPLIT_BOUNDARY) {
blocking_size = detail::GetLastPow2(reduce_num / num_block);
if (blocking_size <= 1) {
......@@ -352,6 +412,9 @@ struct ReduceConfig {
dim3 grid;
};
// when reduce_dim.size() == 1 and reduce_dim[0] == x_dim.size() - 1, this
// function will be used
// blockId.x -> left_num, threadId.x -> reduce_num
template <typename Tx, typename Ty, typename ReduceOp, typename TransformOp,
int BlockDim>
__device__ __forceinline__ void ReduceLastDim(const Tx* x, Ty* y,
......@@ -362,18 +425,25 @@ __device__ __forceinline__ void ReduceLastDim(const Tx* x, Ty* y,
int idx_x = blockIdx.x * reduce_num;
int idx_y = threadIdx.x;
Ty reduce_var = init;
for (int idx_y = threadIdx.x; idx_y < reduce_num; idx_y += BlockDim)
reduce_var = reducer(reduce_var, static_cast<Ty>(x[idx_x + idx_y]));
for (int idx_y = threadIdx.x; idx_y < reduce_num; idx_y += BlockDim) {
reduce_var =
reducer(reduce_var, static_cast<Ty>(transformer(x[idx_x + idx_y])));
}
__syncthreads();
reduce_var =
cub::BlockReduce<Ty, BlockDim>(temp_storage).Reduce(reduce_var, reducer);
if (threadIdx.x == 0) {
y[blockIdx.x] = transformer(reduce_var);
y[blockIdx.x] = reduce_var;
}
}
// when reduce_dim.size() == 1 and reduce_dim[0] != x_dim.size() - 1, this
// function will be used
// eg: x_dim = {nz, ny, nx}, nx != 1, axis can be 0 or 1
// if axis = 1 then grid.z = nz, grid.y = ny / block_size, grid.x = nx / 32
// else grid.z = 1, grid.y = ny / block_size, grid.x = nx /32
template <typename Tx, typename Ty, typename ReduceOp, typename TransformOp>
__device__ __forceinline__ void ReduceHigherDim(const Tx* x, Ty* y,
ReduceOp reducer,
......@@ -383,25 +453,29 @@ __device__ __forceinline__ void ReduceHigherDim(const Tx* x, Ty* y,
int idx = blockIdx.x * blockDim.x + threadIdx.x;
int idy = blockIdx.y * block_size;
Ty temp = init;
Ty reduce_var = init;
if (idx < left_num) {
int loop = reduce_num - idy;
loop = loop > block_size ? block_size : loop;
for (int iy = 0; iy < loop; iy++) {
int id = (idy + iy) * left_num + idx + blockIdx.z * reduce_num * left_num;
reduce_var = reducer(reduce_var, static_cast<Ty>(x[id]));
reduce_var = reducer(reduce_var, static_cast<Ty>(transformer(x[id])));
}
y[idx + blockIdx.y * left_num + blockIdx.z * gridDim.y * left_num] =
static_cast<Ty>(transformer(reduce_var));
reduce_var;
}
}
// when reduce_dim.size() != 1 and reduce_dim.size() != x_dim.size(), this
// function will be used
// blockId.x -> left_num, threadId.x -> reduce_num
template <typename Tx, typename Ty, typename ReduceOp, typename TransformOp,
int BlockDim, int Rank, int ReduceRank>
__device__ __forceinline__ void ReduceAny(
const Tx* x, Ty* y, ReduceOp reducer, TransformOp transformer, Ty init,
const Tx* x, Ty* y, ReduceOp reducer, TransformOp transformer,
int reduce_num, paddle::framework::Array<int, Rank> x_strides,
paddle::framework::Array<int, ReduceRank> reduce_dim,
paddle::framework::Array<int, ReduceRank> reduce_strides,
......@@ -423,20 +497,26 @@ __device__ __forceinline__ void ReduceAny(
}
int idx_x = 0;
for (int k = 0; k < Rank; ++k) idx_x += (sub_index[k] * x_strides[k]);
Ty reduce_var = static_cast<Ty>(x[idx_x]);
for (int k = 0; k < Rank; ++k) {
idx_x += (sub_index[k] * x_strides[k]);
}
Ty reduce_var = static_cast<Ty>(transformer(x[idx_x]));
for (int i = threadIdx.x + BlockDim; i < reduce_num; i += BlockDim) {
int reduce_idx = i;
for (int j = 0; j < ReduceRank; ++j) {
sub_index[reduce_dim[j]] = reduce_idx / reduce_strides[j];
reduce_idx %= reduce_strides[j];
}
int idx_x = 0;
for (int k = 0; k < Rank; ++k) idx_x += (sub_index[k] * x_strides[k]);
reduce_var =
static_cast<Ty>(reducer(reduce_var, static_cast<Ty>(x[idx_x])));
for (int k = 0; k < Rank; ++k) {
idx_x += (sub_index[k] * x_strides[k]);
}
reduce_var = static_cast<Ty>(
reducer(reduce_var, static_cast<Ty>(transformer(x[idx_x]))));
}
__syncthreads();
......@@ -444,10 +524,11 @@ __device__ __forceinline__ void ReduceAny(
cub::BlockReduce<Ty, BlockDim>(temp_storage).Reduce(reduce_var, reducer);
if (threadIdx.x == 0) {
y[blockIdx.x] = transformer(reduce_var);
y[blockIdx.x] = reduce_var;
}
}
// module function designed for global function
template <typename Tx, typename Ty, typename ReduceOp, typename TransformOp,
int BlockDim, int Rank, int ReduceRank, int ReduceType>
__device__ __forceinline__ void ReduceModule(
......@@ -458,17 +539,20 @@ __device__ __forceinline__ void ReduceModule(
paddle::framework::Array<int, ReduceRank> reduce_strides,
paddle::framework::Array<int, Rank - ReduceRank> left_dim,
paddle::framework::Array<int, Rank - ReduceRank> left_strides) {
// reduce_rank == 1 && reduce_dim[0] == x_dim.size() - 1
if (ReduceType == ReduceType::kReduceLastDim) {
ReduceLastDim<Tx, Ty, ReduceOp, TransformOp, BlockDim>(
x, y, reducer, transformer, init, reduce_num);
// reduce_rank == 1 && reduce_dim[0] != x_dim.size() - 1
} else if (ReduceType == ReduceType::kReduceHigherDim) {
ReduceHigherDim<Tx, Ty, ReduceOp, TransformOp>(
x, y, reducer, transformer, init, reduce_num, left_num, blocking_size);
// reduce_rank >= 2
} else {
ReduceAny<Tx, Ty, ReduceOp, TransformOp, BlockDim, Rank, ReduceRank>(
x, y, reducer, transformer, init, reduce_num, x_strides, reduce_dim,
x, y, reducer, transformer, reduce_num, x_strides, reduce_dim,
reduce_strides, left_dim, left_strides);
}
}
......@@ -491,23 +575,22 @@ __global__ void ReduceKernelFunction(
template <typename Tx, typename Ty, int BlockDim, typename ReduceOp,
typename TransformOp, int kRank, int kReduceRank>
static void launchKernel(const Tx* x_data, Ty* y_data,
const platform::Place& place, const ReduceOp& reducer,
const TransformOp& transformer, const Ty& init,
static void LaunchKernel(const Tx* x_data, Ty* y_data, const ReduceOp& reducer,
const TransformOp& transformer, Ty init,
gpuStream_t stream, ReduceConfig<Ty> config) {
#define CUB_REDUCE_TYPE_CASE(type) \
case type: { \
constexpr auto kReduceType = type; \
ReduceKernelFunction< \
Tx, Ty, ReduceOp, TransformOp, BlockDim, kRank, kReduceRank, \
kReduceType><<<config.grid, config.block, 0, stream>>>( \
x_data, config.output_data, reducer, transformer, init, \
config.reduce_num, config.left_num, config.blocking_size, \
detail::from<int, kRank>(config.x_strides), \
detail::from<int, kReduceRank>(config.reduce_dim), \
detail::from<int, kReduceRank>(config.reduce_strides), \
detail::from<int, kRank - kReduceRank>(config.left_dim), \
detail::from<int, kRank - kReduceRank>(config.left_strides)); \
#define CUB_REDUCE_TYPE_CASE(type) \
case type: { \
constexpr auto kReduceType = type; \
ReduceKernelFunction< \
Tx, Ty, ReduceOp, TransformOp, BlockDim, kRank, kReduceRank, \
kReduceType><<<config.grid, config.block, 0, stream>>>( \
x_data, config.output_data, reducer, transformer, init, \
config.reduce_num, config.left_num, config.blocking_size, \
detail::VectorToArray<int, kRank>(config.x_strides), \
detail::VectorToArray<int, kReduceRank>(config.reduce_dim), \
detail::VectorToArray<int, kReduceRank>(config.reduce_strides), \
detail::VectorToArray<int, kRank - kReduceRank>(config.left_dim), \
detail::VectorToArray<int, kRank - kReduceRank>(config.left_strides)); \
} break
switch (config.reduce_type) {
......@@ -523,22 +606,22 @@ static void launchKernel(const Tx* x_data, Ty* y_data,
ReduceKernelFunction<
Ty, Ty, ReduceOp, detail::IdentityFunctor<Ty>, 128, kRank, kReduceRank,
ReduceType::kReduceHigherDim><<<grid, block, 0, stream>>>(
config.output_data, y_data, reducer, detail::IdentityFunctor<Ty>(),
init, config.grid.y, config.left_num, config.grid.y,
detail::from<int, kRank>(config.x_strides),
detail::from<int, kReduceRank>(config.reduce_dim),
detail::from<int, kReduceRank>(config.reduce_strides),
detail::from<int, kRank - kReduceRank>(config.left_dim),
detail::from<int, kRank - kReduceRank>(config.left_strides));
config.output_data, y_data, reducer,
detail::IdentityFunctor<Ty>(config.grid.y), init, config.grid.y,
config.left_num, config.grid.y,
detail::VectorToArray<int, kRank>(config.x_strides),
detail::VectorToArray<int, kReduceRank>(config.reduce_dim),
detail::VectorToArray<int, kReduceRank>(config.reduce_strides),
detail::VectorToArray<int, kRank - kReduceRank>(config.left_dim),
detail::VectorToArray<int, kRank - kReduceRank>(config.left_strides));
}
}
template <typename Tx, typename Ty, int BlockDim, typename ReduceOp,
typename TransformOp>
static void launchReduceKernel(const Tx* x_data, Ty* y_data,
const platform::Place& place,
static void LaunchReduceKernel(const Tx* x_data, Ty* y_data,
const ReduceOp& reducer,
const TransformOp& transformer, const Ty& init,
const TransformOp& transformer, Ty init,
gpuStream_t stream, ReduceConfig<Ty> config) {
int reduce_rank = config.reduce_strides.size();
int rank = config.x_strides.size();
......@@ -552,28 +635,11 @@ static void launchReduceKernel(const Tx* x_data, Ty* y_data,
#define CUB_REDUCE_RANK_CASE(i, ...) \
case i: { \
constexpr auto kReduceRank = i; \
launchKernel<Tx, Ty, BlockDim, ReduceOp, TransformOp, kRank, kReduceRank>( \
x_data, y_data, place, reducer, transformer, init, stream, config); \
LaunchKernel<Tx, Ty, BlockDim, ReduceOp, TransformOp, kRank, kReduceRank>( \
x_data, y_data, reducer, transformer, init, stream, config); \
} break
// launch CUB::Reduce
if (config.reduce_type == static_cast<int>(ReduceType::kReduceAll)) {
cub::TransformInputIterator<Ty, TransformOp, const Tx*> trans_x(
x_data, transformer);
size_t temp_storage_bytes = 0;
cub::DeviceReduce::Reduce(nullptr, temp_storage_bytes, trans_x, y_data,
config.reduce_num, reducer, init, stream);
framework::Tensor tmp;
auto* temp_storage = tmp.mutable_data<uint8_t>(
framework::make_ddim({static_cast<int64_t>(temp_storage_bytes)}),
place);
cub::DeviceReduce::Reduce(temp_storage, temp_storage_bytes, trans_x, y_data,
config.reduce_num, reducer, init, stream);
return;
}
detail::CheckReduceRankIsValid(reduce_rank, rank);
detail::CheckReduceRank(reduce_rank, rank);
switch (rank) {
CUB_RANK_CASE(2, CUB_REDUCE_RANK_CASE(1););
......@@ -595,23 +661,25 @@ static void launchReduceKernel(const Tx* x_data, Ty* y_data,
#undef CUB_REDUCE_RANK_CASE
#undef CUB_RANK_CASE
}
template <typename Tx, typename Ty, typename ReduceOp, typename TransformOp>
void TensorReduceFunc(const framework::Tensor& x, framework::Tensor* y,
std::vector<int> origin_reduce_dims, const Ty& init,
const ReduceOp& reducer, const TransformOp& transformer,
gpuStream_t stream) {
template <typename Tx, typename Ty,
template <typename, typename> class ReduceOp>
void TensorReduceFunctorImpl(const framework::Tensor& x, framework::Tensor* y,
std::vector<int> origin_reduce_dims,
gpuStream_t stream) {
auto x_dim = framework::vectorize<int>(x.dims());
auto config = ReduceConfig<Ty>(origin_reduce_dims, x_dim);
config.Run();
config.Run(); // get the parameters of LaunchReduceKernel
auto x_data = x.data<Tx>();
auto y_data = y->mutable_data<Ty>(x.place());
framework::Tensor tmp;
// after config.run()
// SetOutputData for ReduceHigherDim when should_reduce_again is true,
// temp_output should be stored temp_data in output_data space or stored in
// y_data;
config.SetOutputData(y_data, x.place(), tmp);
framework::Tensor tmp;
config.SetOutputData(y_data, x.place(), &tmp);
if (config.reduce_num == 1) {
auto out_dims = y->dims();
......@@ -619,17 +687,36 @@ void TensorReduceFunc(const framework::Tensor& x, framework::Tensor* y,
y->Resize(out_dims);
return;
}
using TransformOp = typename ReduceOp<Tx, Ty>::Transformer;
auto reducer = ReduceOp<Tx, Ty>();
// launch CUB::Reduce
if (config.reduce_type == static_cast<int>(ReduceType::kReduceAll)) {
cub::TransformInputIterator<Ty, TransformOp, const Tx*> trans_x(
x_data, TransformOp(config.reduce_num));
size_t temp_storage_bytes = 0;
cub::DeviceReduce::Reduce(nullptr, temp_storage_bytes, trans_x, y_data,
config.reduce_num, reducer, reducer.initial(),
stream);
framework::Tensor tmp;
auto* temp_storage = tmp.mutable_data<uint8_t>(
framework::make_ddim({static_cast<int64_t>(temp_storage_bytes)}),
x.place());
cub::DeviceReduce::Reduce(temp_storage, temp_storage_bytes, trans_x, y_data,
config.reduce_num, reducer, reducer.initial(),
stream);
#define CUB_BLOCK_DIM_CASE(block_dim) \
case block_dim: { \
constexpr auto kBlockDim = block_dim; \
launchReduceKernel<Tx, Ty, block_dim, ReduceOp, TransformOp>( \
x_data, y_data, x.place(), reducer, transformer, init, stream, \
config); \
return;
}
#define CUB_BLOCK_DIM_CASE(block_dim) \
case block_dim: { \
constexpr auto kBlockDim = block_dim; \
LaunchReduceKernel<Tx, Ty, block_dim, ReduceOp<Tx, Ty>, TransformOp>( \
x_data, y_data, reducer, TransformOp(config.reduce_num), \
reducer.initial(), stream, config); \
} break
switch (detail::GetDesiredBlockDim(config.reduce_num)) {
CUB_BLOCK_DIM_CASE(512);
switch (detail::GetBlockDim(config.reduce_num)) {
CUB_BLOCK_DIM_CASE(256);
CUB_BLOCK_DIM_CASE(128);
CUB_BLOCK_DIM_CASE(64);
......@@ -642,5 +729,46 @@ void TensorReduceFunc(const framework::Tensor& x, framework::Tensor* y,
#undef CUB_BLOCK_DIM_CASE
}
template <typename Tx, template <typename, typename> class ReduceOp>
struct TensorReduceFunc {
const framework::Tensor& x;
framework::Tensor* y;
std::vector<int> origin_reduce_dims;
gpuStream_t stream;
TensorReduceFunc(const framework::Tensor& x, framework::Tensor* y,
std::vector<int> origin_reduce_dims, gpuStream_t stream)
: x(x), y(y), origin_reduce_dims(origin_reduce_dims), stream(stream) {}
template <typename Ty>
void apply() const {
TensorReduceFunctorImpl<Tx, Ty, ReduceOp>(x, y, origin_reduce_dims, stream);
}
};
template <typename T, template <typename, typename> class ReduceOp>
class ReduceCudaKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
bool reduce_all = context.Attr<bool>("reduce_all");
const Tensor* input = context.Input<Tensor>("X");
Tensor* output = context.Output<Tensor>("Out");
auto out_dtype = context.Attr<int>("out_dtype");
std::vector<int> dims = context.Attr<std::vector<int>>("dim");
std::vector<int> reduce_dims =
detail::GetReduceDim(dims, input->dims().size(), reduce_all);
gpuStream_t stream = context.cuda_device_context().stream();
if (out_dtype >= 0) {
framework::VisitDataTypeSmall(
static_cast<framework::proto::VarType::Type>(out_dtype),
TensorReduceFunc<T, ReduceOp>(*input, output, reduce_dims, stream));
} else {
TensorReduceFunctorImpl<T, T, ReduceOp>(*input, output, reduce_dims,
stream);
}
}
};
} // namespace operators
} // namespace paddle
......@@ -12,26 +12,22 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/reduce_ops/reduce_functor_op.h"
#include "paddle/fluid/operators/reduce_ops/reduce_op.cu.h"
#include "paddle/fluid/operators/reduce_ops/reduce_prod_op.h"
// reduce_prod
#ifdef __HIPCC__
// Eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorReductionGpu.h:922
// do not support double in HIPCC platform (Eigen3 to be fixed)
REGISTER_OP_CUDA_KERNEL(reduce_prod,
ops::ReduceKernel<paddle::platform::CUDADeviceContext,
float, ops::ProdFunctor>,
ops::ReduceKernel<paddle::platform::CUDADeviceContext,
int, ops::ProdFunctor>,
ops::ReduceKernel<paddle::platform::CUDADeviceContext,
int64_t, ops::ProdFunctor>);
REGISTER_OP_CUDA_KERNEL(
reduce_prod, ops::ReduceCudaKernel<float, paddle::operators::CustomMul>,
ops::ReduceCudaKernel<int, paddle::operators::CustomMul>,
ops::ReduceCudaKernel<int64_t, paddle::operators::CustomMul>);
#else
REGISTER_OP_CUDA_KERNEL(reduce_prod,
ops::ReduceKernel<paddle::platform::CUDADeviceContext,
float, ops::ProdFunctor>,
ops::ReduceKernel<paddle::platform::CUDADeviceContext,
double, ops::ProdFunctor>,
ops::ReduceKernel<paddle::platform::CUDADeviceContext,
int, ops::ProdFunctor>,
ops::ReduceKernel<paddle::platform::CUDADeviceContext,
int64_t, ops::ProdFunctor>);
REGISTER_OP_CUDA_KERNEL(
reduce_prod, ops::ReduceCudaKernel<float, paddle::operators::CustomMul>,
ops::ReduceCudaKernel<int, paddle::operators::CustomMul>,
ops::ReduceCudaKernel<double, paddle::operators::CustomMul>,
ops::ReduceCudaKernel<int64_t, paddle::operators::CustomMul>);
#endif
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册