未验证 提交 9a8a4c77 编写于 作者: N niuliling123 提交者: GitHub

Delete cub_reduce.h and modified the TensorReduce to TensorReduceFunctorImpl (#38197)

上级 431a2d6a
......@@ -20,7 +20,7 @@ limitations under the License. */
#include <unordered_map>
#include <vector>
#include "paddle/fluid/operators/reduce_ops/cub_reduce.h"
#include "paddle/fluid/operators/reduce_ops/reduce_op.cu.h"
namespace paddle {
namespace operators {
......@@ -28,16 +28,6 @@ namespace operators {
using framework::Tensor;
using framework::DDim;
template <typename Tout>
struct IdentityFunctor {
HOSTDEVICE explicit inline IdentityFunctor() {}
template <typename U>
HOSTDEVICE inline Tout operator()(const U& x) const {
return static_cast<Tout>(x);
}
};
template <typename T>
class CUDABroadcastTensorsGradOpKernel : public framework::OpKernel<T> {
public:
......@@ -99,9 +89,9 @@ class CUDABroadcastTensorsGradOpKernel : public framework::OpKernel<T> {
} else {
// reduce_sum implementation on CUDA
auto stream = context.cuda_device_context().stream();
TensorReduce<T, T, cub::Sum, IdentityFunctor<T>>(
*input_tensor, output_tensor, reduce_dims_vec, static_cast<T>(0),
cub::Sum(), IdentityFunctor<T>(), stream);
TensorReduceFunctorImpl<T, T, kps::AddFunctor, kps::IdentityFunctor<T>>(
*input_tensor, output_tensor, kps::IdentityFunctor<T>(),
reduce_dims_vec, stream);
}
}
}
......
......@@ -15,20 +15,16 @@ limitations under the License. */
#include <thrust/fill.h>
#include "paddle/fluid/operators/controlflow/compare_all_op.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_impl.cu.h"
#include "paddle/fluid/operators/reduce_ops/cub_reduce.h"
#include "paddle/fluid/operators/reduce_ops/reduce_op.cu.h"
namespace paddle {
namespace operators {
template <typename T>
struct IdentityFunctor {
HOSTDEVICE explicit inline IdentityFunctor() {}
HOSTDEVICE inline T operator()(const T& x) const { return x; }
};
struct BitwiseAdd {
// Bitwise add operator, returns <tt>a + b</tt>
template <typename T>
inline T initial() { return static_cast<T>(true); }
__host__ __device__ __forceinline__ T operator()(const T& a,
const T& b) const {
return a & b;
......@@ -67,9 +63,9 @@ class CompareReduceOpKernel
reduce_dims.resize(tmp.dims().size());
for (int i = 0; i < reduce_dims.size(); ++i) reduce_dims[i] = i;
auto stream = context.cuda_device_context().stream();
TensorReduce<bool, bool, BitwiseAdd, IdentityFunctor<bool>>(
tmp, z, reduce_dims, true, BitwiseAdd(), IdentityFunctor<bool>(),
stream);
TensorReduceFunctorImpl<bool, bool, BitwiseAdd,
kps::IdentityFunctor<bool>>(
tmp, z, kps::IdentityFunctor<bool>(), reduce_dims, stream);
}
}
};
......
......@@ -33,7 +33,7 @@ namespace cub = hipcub;
#include "paddle/fluid/operators/elementwise/elementwise_functor.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h"
#include "paddle/fluid/operators/kernel_primitives/kernel_primitives.h"
#include "paddle/fluid/operators/reduce_ops/reduce_functor_op.h"
#include "paddle/fluid/operators/reduce_ops/reduce_op.cu.h"
#include "paddle/fluid/platform/fast_divmod.h"
namespace paddle {
......@@ -41,8 +41,6 @@ namespace operators {
#define MAX_INPUT_NUM 2
namespace kps = paddle::operators::kernel_primitives;
template <typename T>
using CudnnDataType = platform::CudnnDataType<T>;
template <typename T>
......
......@@ -19,7 +19,7 @@ limitations under the License. */
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/platform/for_range.h"
#if defined(__NVCC__) || defined(__HIPCC__)
#include "paddle/fluid/operators/reduce_ops/cub_reduce.h"
#include "paddle/fluid/operators/reduce_ops/reduce_op.cu.h"
#include "thrust/device_vector.h"
#endif
......@@ -237,15 +237,6 @@ struct KronGradElemFunctor<platform::complex<T>> {
const int ndims_;
};
struct IdentityFunctor {
HOSTDEVICE explicit inline IdentityFunctor() {}
template <typename U>
HOSTDEVICE inline U operator()(const U& x) const {
return x;
}
};
template <typename DeviceContext, typename T>
struct KronGradOpFunctor {
void operator()(const DeviceContext& dev_ctx, const framework::Tensor& dout,
......@@ -314,14 +305,12 @@ struct KronGradOpFunctor {
#if defined(__NVCC__) || defined(__HIPCC__)
auto stream = dev_ctx.stream(); // it is a cuda device_context
if (dx) {
TensorReduce<T, T, cub::Sum, IdentityFunctor>(
dout_x, dx, {1}, static_cast<T>(0), cub::Sum(), IdentityFunctor(),
stream);
TensorReduceFunctorImpl<T, T, kps::AddFunctor, kps::IdentityFunctor<T>>(
dout_x, dx, kps::IdentityFunctor<T>(), {1}, stream);
}
if (dy) {
TensorReduce<T, T, cub::Sum, IdentityFunctor>(
dout_y, dy, {1}, static_cast<T>(0), cub::Sum(), IdentityFunctor(),
stream);
TensorReduceFunctorImpl<T, T, kps::AddFunctor, kps::IdentityFunctor<T>>(
dout_y, dy, kps::IdentityFunctor<T>(), {1}, stream);
}
#else
auto* place = dev_ctx.eigen_device();
......
......@@ -31,7 +31,7 @@ limitations under the License. */
#include "paddle/pten/include/linalg.h"
#if defined(__NVCC__) || defined(__HIPCC__)
#include "paddle/fluid/operators/reduce_ops/cub_reduce.h"
#include "paddle/fluid/operators/reduce_ops/reduce_op.cu.h"
#endif
namespace paddle {
......@@ -39,24 +39,14 @@ namespace operators {
using framework::Tensor;
struct IdentityFunctor {
HOSTDEVICE explicit inline IdentityFunctor() {}
template <typename U>
HOSTDEVICE inline U operator()(const U& x) const {
return x;
}
};
template <typename DeviceContext, typename T>
void ReduceSumForMatmulGrad(const Tensor* input, Tensor* output,
const std::vector<int>& reduce_dims,
const paddle::framework::ExecutionContext& ctx) {
#if defined(__NVCC__) || defined(__HIPCC__)
auto stream = ctx.cuda_device_context().stream();
TensorReduce<T, T, cub::Sum, IdentityFunctor>(*input, output, reduce_dims,
static_cast<T>(0), cub::Sum(),
IdentityFunctor(), stream);
TensorReduceFunctorImpl<T, T, kps::AddFunctor, kps::IdentityFunctor<T>>(
*input, output, kps::IdentityFunctor<T>(), reduce_dims, stream);
#else
ReduceKernelFunctor<DeviceContext, T, ops::SumFunctor>(
input, output, reduce_dims, true, false, ctx)
......
......@@ -13,11 +13,11 @@
// limitations under the License.
#include "gtest/gtest.h"
#include "paddle/fluid/operators/reduce_ops/cub_reduce.h"
#include "paddle/fluid/operators/reduce_ops/reduce_op.cu.h"
namespace paddle {
namespace operators {
namespace detail {
namespace details {
TEST(test_reduce_rank_check, all) {
using EnforceNotMet = paddle::platform::EnforceNotMet;
......@@ -39,15 +39,15 @@ TEST(test_reduce_rank_check, all) {
}
if (is_valid) {
CheckReduceRankIsValid(reduce_rank, rank);
CheckReduceRank(reduce_rank, rank);
} else {
ASSERT_THROW(CheckReduceRankIsValid(reduce_rank, rank),
ASSERT_THROW(CheckReduceRank(reduce_rank, rank),
paddle::platform::EnforceNotMet);
}
}
}
}
} // namespace detail
} // namespace details
} // namespace operators
} // namespace paddle
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <algorithm>
#include <cmath>
#include <numeric>
#include <set>
#include <vector>
#ifdef __NVCC__
#include "cub/cub.cuh" // NOLINT
#endif
#ifdef __HIPCC__
#include <hipcub/hipcub.hpp>
namespace cub = hipcub;
#endif
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/operators/amp/fp16_type_traits.h"
namespace paddle {
namespace operators {
namespace detail {
template <typename T, size_t ElementCount>
struct Array {
public:
HOSTDEVICE inline Array() {}
HOSTDEVICE inline T& operator[](size_t index) { return data_[index]; }
HOSTDEVICE inline const T& operator[](size_t index) const {
return data_[index];
}
HOSTDEVICE constexpr inline size_t size() const { return ElementCount; }
template <typename VectorLikeType>
static inline Array<T, ElementCount> From(const VectorLikeType& vec) {
PADDLE_ENFORCE_EQ(vec.size(), ElementCount,
platform::errors::InvalidArgument(
"Cub reduce Array: size not match. Received "
"vec.size() %d != ElementCount %d.",
vec.size(), ElementCount));
size_t n = static_cast<size_t>(vec.size());
Array<T, ElementCount> ret;
for (size_t i = 0; i < n; ++i) ret[i] = vec[i];
return ret;
}
private:
T data_[ElementCount];
};
// reduce the 1d array to one element
template <typename Tx, typename MPType, typename Ty, typename ReduceOp,
typename TransformOp, int BlockDim>
__global__ void ReduceKernel1D(const Tx* x, Ty* y, ReduceOp reducer,
TransformOp transformer, MPType init,
int reduce_num) {
int thread_id = blockIdx.x * blockDim.x + threadIdx.x;
typedef cub::BlockReduce<MPType, BlockDim> BlockReduce;
__shared__ typename BlockReduce::TempStorage temp_storage;
MPType local_data = init;
for (int i = thread_id; i < reduce_num; i += gridDim.x * blockDim.x) {
local_data = static_cast<MPType>(
reducer(local_data, static_cast<MPType>(transformer(x[i]))));
}
__syncthreads();
local_data = BlockReduce(temp_storage).Reduce(local_data, reducer);
if (threadIdx.x == 0) {
y[blockIdx.x] = static_cast<Ty>(local_data);
}
}
// reduce the last axis of 2d array
template <typename Tx, typename MPType, typename Ty, typename ReduceOp,
typename TransformOp, int BlockDim>
__global__ void ReduceKernel2D(const Tx* x, Ty* y, ReduceOp reducer,
TransformOp transformer, MPType init,
int reduce_num) {
__shared__
typename cub::BlockReduce<MPType, BlockDim>::TempStorage temp_storage;
int idx_x = blockIdx.x * reduce_num;
int idx_y = threadIdx.x;
MPType reduce_var = init;
for (int idx_y = threadIdx.x; idx_y < reduce_num; idx_y += BlockDim)
reduce_var =
reducer(reduce_var, static_cast<MPType>(transformer(x[idx_x + idx_y])));
__syncthreads();
reduce_var = cub::BlockReduce<MPType, BlockDim>(temp_storage)
.Reduce(reduce_var, reducer);
if (threadIdx.x == 0) {
y[blockIdx.x] = static_cast<Ty>(reduce_var);
}
}
template <typename Tx, typename MPType, typename Ty, typename ReduceOp,
typename TransformOp, int BlockDim, int Rank, int ReduceRank>
__global__ void ReduceKernel(const Tx* x, Ty* y, ReduceOp reducer,
TransformOp transformer, MPType init,
int reduce_num, Array<int, Rank> x_strides,
Array<int, ReduceRank> reduce_dim,
Array<int, ReduceRank> reduce_strides,
Array<int, Rank - ReduceRank> left_dim,
Array<int, Rank - ReduceRank> left_strides) {
__shared__
typename cub::BlockReduce<MPType, BlockDim>::TempStorage temp_storage;
Array<int, Rank> sub_index;
int left_idx = blockIdx.x;
for (int i = 0; i < Rank - ReduceRank; ++i) {
sub_index[left_dim[i]] = left_idx / left_strides[i];
left_idx %= left_strides[i];
}
int reduce_idx = threadIdx.x;
for (int j = 0; j < ReduceRank; ++j) {
sub_index[reduce_dim[j]] = reduce_idx / reduce_strides[j];
reduce_idx %= reduce_strides[j];
}
int idx_x = 0;
for (int k = 0; k < Rank; ++k) idx_x += (sub_index[k] * x_strides[k]);
MPType reduce_var = static_cast<MPType>(transformer(x[idx_x]));
for (int i = threadIdx.x + BlockDim; i < reduce_num; i += BlockDim) {
int reduce_idx = i;
for (int j = 0; j < ReduceRank; ++j) {
sub_index[reduce_dim[j]] = reduce_idx / reduce_strides[j];
reduce_idx %= reduce_strides[j];
}
int idx_x = 0;
for (int k = 0; k < Rank; ++k) idx_x += (sub_index[k] * x_strides[k]);
reduce_var = static_cast<MPType>(
reducer(reduce_var, static_cast<MPType>(transformer(x[idx_x]))));
}
__syncthreads();
reduce_var = cub::BlockReduce<MPType, BlockDim>(temp_storage)
.Reduce(reduce_var, reducer);
if (threadIdx.x == 0) {
y[blockIdx.x] = static_cast<Ty>(reduce_var);
}
}
static inline std::vector<int> GetStrides(const std::vector<int>& dims) {
int n = static_cast<int>(dims.size());
if (n == 0) return std::vector<int>();
std::vector<int> strides(n);
strides.back() = 1;
for (int i = n - 2; i >= 0; --i) {
strides[i] = strides[i + 1] * dims[i + 1];
}
return strides;
}
static inline std::vector<int> GetStrides(const std::vector<int>& dims,
const std::vector<int>& idx) {
int n = static_cast<int>(idx.size());
if (n == 0) return std::vector<int>();
std::vector<int> strides(n);
strides.back() = 1;
for (int i = n - 2; i >= 0; --i) {
strides[i] = strides[i + 1] * dims[idx[i + 1]];
}
return strides;
}
#ifdef __HIPCC__
constexpr int kMaxBlockDim = 256;
#else
constexpr int kMaxBlockDim = 512;
#endif
static inline int GetDesiredBlockDim(int block_dim) {
return block_dim >= kMaxBlockDim
? kMaxBlockDim
: (1 << static_cast<int>(std::log2(block_dim)));
}
static inline void CheckReduceRankIsValid(int reduce_rank, int rank) {
if (rank % 2 == 0) {
PADDLE_ENFORCE_EQ(reduce_rank, rank / 2,
platform::errors::InvalidArgument(
"ReduceOp: invalid reduce rank. When rank = %d, "
"reduce_rank must be %d, but got %d.",
rank, rank / 2, reduce_rank));
} else {
auto lower_rank = (rank - 1) / 2;
auto upper_rank = (rank + 1) / 2;
PADDLE_ENFORCE_EQ(
reduce_rank == lower_rank || reduce_rank == upper_rank, true,
platform::errors::InvalidArgument(
"ReduceOp: invalid reduce rank. When rank = %d, reduce_rank "
"must be %d or %d, but got %d.",
rank, lower_rank, upper_rank, reduce_rank));
}
}
template <typename Tx, typename MPType, typename Ty, typename ReduceOp,
typename TransformOp, int BlockDim>
typename std::enable_if<!std::is_same<Tx, paddle::platform::float16>::value,
void>::type
LaunchCubReduceKernel(const Tx* x_data, Ty* y_data,
const platform::Place& place, const ReduceOp& reducer,
const TransformOp& transformer, const MPType& init,
int reduce_num, gpuStream_t stream) {
cub::TransformInputIterator<Ty, TransformOp, const Tx*> trans_x(x_data,
transformer);
size_t temp_storage_bytes = 0;
cub::DeviceReduce::Reduce(nullptr, temp_storage_bytes, trans_x, y_data,
reduce_num, reducer, init, stream);
framework::Tensor tmp;
auto* temp_storage = tmp.mutable_data<uint8_t>(
framework::make_ddim({static_cast<int64_t>(temp_storage_bytes)}), place);
cub::DeviceReduce::Reduce(temp_storage, temp_storage_bytes, trans_x, y_data,
reduce_num, reducer, init, stream);
}
template <typename Tx, typename MPType, typename Ty, typename ReduceOp,
typename TransformOp, int BlockDim>
typename std::enable_if<std::is_same<Tx, paddle::platform::float16>::value,
void>::type
LaunchCubReduceKernel(const Tx* x_data, Ty* y_data,
const platform::Place& place, const ReduceOp& reducer,
const TransformOp& transformer, const MPType& init,
int reduce_num, gpuStream_t stream) {
int element_per_block = BlockDim * 10;
int block_per_grid = (reduce_num + element_per_block - 1) / element_per_block;
framework::Tensor tmp;
auto* temp_storage = tmp.mutable_data<MPType>(
framework::make_ddim(
{static_cast<int64_t>(block_per_grid * sizeof(MPType))}),
place);
// each block reduce number to interim result
ReduceKernel1D<Tx, MPType, MPType, ReduceOp, TransformOp,
BlockDim><<<block_per_grid, BlockDim, 0, stream>>>(
x_data, temp_storage, reducer, transformer, init, reduce_num);
// reduce all number to final result
ReduceKernel1D<MPType, MPType, Ty, ReduceOp, TransformOp,
BlockDim><<<1, BlockDim, 0, stream>>>(
temp_storage, y_data, reducer, transformer, init, block_per_grid);
}
template <typename Tx, typename Ty, int BlockDim, typename ReduceOp,
typename TransformOp>
static void TensorReduceImpl(
const Tx* x_data, Ty* y_data, const platform::Place& place,
const ReduceOp& reducer, const TransformOp& transformer, const Ty& init,
int left_num, int reduce_num, const std::vector<int>& x_strides,
const std::vector<int>& reduce_dim, const std::vector<int>& reduce_strides,
const std::vector<int>& left_dim, const std::vector<int>& left_strides,
gpuStream_t stream) {
using MPType = typename details::MPTypeTrait<Ty>::Type;
MPType init_mp = static_cast<MPType>(init);
#define CUB_RANK_CASE(i, ...) \
case i: { \
constexpr auto kRank = i; \
switch (reduce_rank) { __VA_ARGS__; } \
} break
#define CUB_REDUCE_RANK_CASE(i, ...) \
case i: { \
constexpr auto kReduceRank = i; \
ReduceKernel<Tx, MPType, Ty, ReduceOp, TransformOp, BlockDim, kRank, \
kReduceRank><<<left_num, BlockDim, 0, stream>>>( \
x_data, y_data, reducer, transformer, init_mp, reduce_num, \
Array<int, kRank>::From(x_strides), \
Array<int, kReduceRank>::From(reduce_dim), \
Array<int, kReduceRank>::From(reduce_strides), \
Array<int, kRank - kReduceRank>::From(left_dim), \
Array<int, kRank - kReduceRank>::From(left_strides)); \
} break
int rank = x_strides.size();
int reduce_rank = reduce_strides.size();
if (rank == reduce_rank) {
LaunchCubReduceKernel<Tx, MPType, Ty, ReduceOp, TransformOp, BlockDim>(
x_data, y_data, place, reducer, transformer, init_mp, reduce_num,
stream);
return;
}
if (rank == 2 && reduce_rank == 1 && reduce_dim[0] == 1) {
ReduceKernel2D<Tx, MPType, Ty, ReduceOp, TransformOp,
BlockDim><<<left_num, BlockDim, 0, stream>>>(
x_data, y_data, reducer, transformer, init_mp, reduce_num);
return;
}
/*
if (rank == 3 && reduce_rank == 1 && reduce_dim[0] == 1) {
// TODO(liangdun): we can optimize 3d case which the 2nd axis is reduced.
// Currently, it is handled by code below, but inefficient
return;
}
*/
/**
* Since we have combined the adjacent reduce dimensions inside TensorReduce,
* The reduce ranks and non-reduce ranks must be interleaving. That is to say,
* the rank of Tensor must be `1010...` or `0101...` where 1 represents that
* the dimension is about to be reduced.
*
* Therefore,
* If rank is odd, only need to switch-case (rank - 1)/2 and (rank + 1)/2.
* If rank is even, only need to switch-case rank/2.
*
* The total switch-case numbers reduce from 1+2+3+...+8=36 to (1+2)*4=12,
* it would speed up compiling and make the binary size lower.
*/
CheckReduceRankIsValid(reduce_rank, rank);
switch (rank) {
CUB_RANK_CASE(2, CUB_REDUCE_RANK_CASE(1););
CUB_RANK_CASE(3, CUB_REDUCE_RANK_CASE(1); CUB_REDUCE_RANK_CASE(2););
CUB_RANK_CASE(4, CUB_REDUCE_RANK_CASE(2););
CUB_RANK_CASE(5, CUB_REDUCE_RANK_CASE(2); CUB_REDUCE_RANK_CASE(3););
CUB_RANK_CASE(6, CUB_REDUCE_RANK_CASE(3););
CUB_RANK_CASE(7, CUB_REDUCE_RANK_CASE(3); CUB_REDUCE_RANK_CASE(4););
CUB_RANK_CASE(8, CUB_REDUCE_RANK_CASE(4););
CUB_RANK_CASE(9, CUB_REDUCE_RANK_CASE(4); CUB_REDUCE_RANK_CASE(5););
}
#undef CUB_REDUCE_RANK_CASE
#undef CUB_RANK_CASE
}
} // namespace detail
template <typename Tx, typename Ty, typename ReduceOp, typename TransformOp>
void TensorReduce(const framework::Tensor& x, framework::Tensor* y,
std::vector<int> origin_reduce_dims, const Ty& init,
const ReduceOp& reducer, const TransformOp& transformer,
gpuStream_t stream) {
auto x_dim = framework::vectorize<int>(x.dims());
std::vector<int> new_x_dim, new_reduce_dims;
int is_reduced = 0;
for (auto e : origin_reduce_dims) {
auto pos = e >= 0 ? e : e + x_dim.size();
is_reduced |= 1 << e;
}
for (int i = 0; i < x_dim.size(); i++) {
if ((i == 0) || (((is_reduced >> i) ^ (is_reduced >> (i - 1))) & 1)) {
new_x_dim.push_back(x_dim[i]);
if ((is_reduced >> i) & 1)
new_reduce_dims.push_back(new_x_dim.size() - 1);
} else {
new_x_dim[new_x_dim.size() - 1] *= x_dim[i];
}
}
x_dim = new_x_dim;
origin_reduce_dims = new_reduce_dims;
int x_rank = static_cast<int>(x_dim.size());
std::set<int> left_set, reduce_set;
for (int i = 0; i < x_rank; ++i) left_set.insert(i);
for (auto e : origin_reduce_dims) {
left_set.erase(e);
reduce_set.insert(e);
}
std::vector<int> reduce_dim(reduce_set.begin(), reduce_set.end());
std::vector<int> left_dim(left_set.begin(), left_set.end());
std::vector<int> x_strides = detail::GetStrides(x_dim);
std::vector<int> reduce_strides = detail::GetStrides(x_dim, reduce_dim);
std::vector<int> left_strides = detail::GetStrides(x_dim, left_dim);
int reduce_num = reduce_strides[0] * x_dim[reduce_dim[0]];
int left_num = 1;
if (left_dim.size()) left_num = left_strides[0] * x_dim[left_dim[0]];
std::vector<int> y_dim(left_dim.size());
for (int i = 0; i < left_dim.size(); ++i) {
y_dim[i] = x_dim[left_dim[i]];
}
auto x_data = x.data<Tx>();
auto y_data = y->mutable_data<Ty>(x.place());
if (reduce_num == 1) {
auto out_dims = y->dims();
framework::TensorCopy(x, y->place(), y);
y->Resize(out_dims);
return;
}
#define CUB_BLOCK_DIM_CASE(block_dim) \
case block_dim: { \
constexpr auto kBlockDim = block_dim; \
detail::TensorReduceImpl<Tx, Ty, block_dim, ReduceOp, TransformOp>( \
x_data, y_data, x.place(), reducer, transformer, init, left_num, \
reduce_num, x_strides, reduce_dim, reduce_strides, left_dim, \
left_strides, stream); \
} break
switch (detail::GetDesiredBlockDim(reduce_num)) {
CUB_BLOCK_DIM_CASE(512);
CUB_BLOCK_DIM_CASE(256);
CUB_BLOCK_DIM_CASE(128);
CUB_BLOCK_DIM_CASE(64);
CUB_BLOCK_DIM_CASE(32);
CUB_BLOCK_DIM_CASE(16);
CUB_BLOCK_DIM_CASE(8);
CUB_BLOCK_DIM_CASE(4);
CUB_BLOCK_DIM_CASE(2);
}
#undef CUB_BLOCK_DIM_CASE
}
template <typename Tx, typename ReduceOp, template <typename> class TransformOp>
struct TensorReduceFunctor {
const framework::Tensor& x;
framework::Tensor* y;
std::vector<int> origin_reduce_dims;
const double& init;
const ReduceOp& reducer;
gpuStream_t stream;
TensorReduceFunctor(const framework::Tensor& x, framework::Tensor* y,
std::vector<int> origin_reduce_dims, const double& init,
const ReduceOp& reducer, gpuStream_t stream)
: x(x),
y(y),
origin_reduce_dims(origin_reduce_dims),
init(init),
reducer(reducer),
stream(stream) {}
template <typename Ty>
void apply() const {
const Ty& init_cast = static_cast<Ty>(init);
TensorReduce<Tx, Ty, ReduceOp, TransformOp<Ty>>(x, y, origin_reduce_dims,
init_cast, reducer,
TransformOp<Ty>(), stream);
}
};
} // namespace operators
} // namespace paddle
......@@ -12,8 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/reduce_ops/cub_reduce.h"
#include "paddle/fluid/operators/reduce_ops/frobenius_norm_op.h"
#include "paddle/fluid/operators/reduce_ops/reduce_op.cu.h"
template <typename T>
using CUDAFrobeniusNormKernel =
......
......@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/reduce_ops/cub_reduce.h"
#include "paddle/fluid/operators/reduce_ops/reduce_op.cu.h"
#include "paddle/fluid/operators/reduce_ops/reduce_sum_op.h"
template <typename T>
......
......@@ -26,7 +26,7 @@ limitations under the License. */
#include "paddle/fluid/operators/reduce_ops/reduce_sum_op.h"
#include "paddle/fluid/operators/squeeze_op.h"
#if defined(__NVCC__) || defined(__HIPCC__)
#include "paddle/fluid/operators/reduce_ops/cub_reduce.h"
#include "paddle/fluid/operators/reduce_ops/reduce_op.cu.h"
#endif
#define MAX_RANK_SUPPORTED 6
......@@ -39,24 +39,14 @@ using framework::To32BitIndex;
constexpr int kMULMKLDNNINT8 = 1;
struct IdentityFunctor {
HOSTDEVICE explicit inline IdentityFunctor() {}
template <typename U>
HOSTDEVICE inline U operator()(const U& x) const {
return x;
}
};
template <typename DeviceContext, typename T>
void ReduceSumForSolve(const Tensor* input, Tensor* output,
const std::vector<int>& reduce_dims, bool keep_dim,
const paddle::framework::ExecutionContext& ctx) {
#if defined(__NVCC__) || defined(__HIPCC__)
auto stream = ctx.cuda_device_context().stream();
TensorReduce<T, T, cub::Sum, IdentityFunctor>(*input, output, reduce_dims,
static_cast<T>(0), cub::Sum(),
IdentityFunctor(), stream);
TensorReduceFunctorImpl<T, T, kps::AddFunctor, kps::IdentityFunctor<T>>(
*input, output, kps::IdentityFunctor<T>(), reduce_dims, stream);
#else
ReduceKernelFunctor<DeviceContext, T, ops::SumFunctor>(
input, output, reduce_dims, keep_dim, false, ctx)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册