Unverified commit 524389ee, authored by niuliling123 and committed by GitHub

Add the TransformOp parameter to TensorReduceFunctorImpl (#38135)

* Add the TransformOp parameter to TensorReduceFunctorImpl
Parent commit: be874c08
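In short: TensorReduceFunctorImpl previously took a combined reduce-op class (CustomSum, SquareSum, ExpAndSum, AbsAndMax, ...) whose nested Transformer typedef supplied the element-wise transform. After this commit the reducer (e.g. kps::AddFunctor) and the transform (e.g. kps::IdentityFunctor, kps::SquareFunctor) are separate template parameters, and the transform is also passed in as a runtime object so it can carry state such as reduce_num (see kps::DivideFunctor<T>(reduce_num) in the pool_op hunk). The toy program below only sketches that new calling convention; it is not Paddle code, and all tensor/stream plumbing is omitted.

// Self-contained sketch of the new convention: the reducer is a
// one-type-parameter template and the element-wise transform is an ordinary
// object supplied by the caller (mirrors the new TensorReduceFunctorImpl
// signature, minus tensors, streams and the CUDA launch logic).
#include <cstdio>
#include <vector>

template <typename T>
struct AddFunctor {  // stand-in for kps::AddFunctor
  T initial() const { return T(0); }
  T operator()(const T& a, const T& b) const { return a + b; }
};

template <typename Tx, typename Ty = Tx>
struct SquareFunctor {  // stand-in for kps::SquareFunctor
  Ty operator()(const Tx& x) const {
    return static_cast<Ty>(x) * static_cast<Ty>(x);
  }
};

template <typename Tx, typename Ty, template <typename> class ReduceOp,
          typename TransformOp>
Ty Reduce(const std::vector<Tx>& x, const TransformOp& transform) {
  ReduceOp<Ty> reducer;
  Ty acc = reducer.initial();
  for (const Tx& v : x) acc = reducer(acc, transform(v));
  return acc;
}

int main() {
  std::vector<float> x = {1.f, 2.f, 3.f};
  // The old style bundled the transform inside the reduce-op class; now the
  // transform travels separately, exactly as in the call sites changed below.
  float sum_sq = Reduce<float, float, AddFunctor>(x, SquareFunctor<float>());
  std::printf("%f\n", sum_sq);  // 14.000000
  return 0;
}

Decoupling the reducer from the transform is what allows the per-operator wrapper structs defined next to each kernel (SquareSum, ExpAndSum, NonzeroAndSum, AbsAndMax/AbsAndMin, the local IdentityFunctor copies, ...) to be deleted in favour of the shared kernel_primitives functors, as the removed blocks in the hunks below show.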
@@ -18,29 +18,6 @@ limitations under the License. */
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
template <typename Tx, typename Ty = Tx>
struct SquareTransformer {
HOSTDEVICE explicit inline SquareTransformer(int n) {}
HOSTDEVICE inline Ty operator()(const Tx& x) const {
return static_cast<Ty>(x) * static_cast<Ty>(x);
}
HOSTDEVICE inline Ty operator()(const Tx* x) const {
return static_cast<Ty>(x[0]) * static_cast<Ty>(x[0]);
}
};
template <typename Tx, typename Ty = Tx>
struct SquareSum {
using Transformer = SquareTransformer<Tx, Ty>;
inline Ty initial() { return static_cast<Ty>(0.0f); }
__device__ __forceinline__ Ty operator()(const Ty& a, const Ty& b) const {
return b + a;
}
};
template <>
class ClipByNormKernel<platform::CUDADeviceContext, platform::float16>
@@ -97,8 +74,10 @@ class ClipByNormKernel<platform::CUDADeviceContext, platform::float16>
}
Tensor tmp = context.AllocateTmpTensor<float, platform::CUDADeviceContext>(
{1}, dev_ctx);
TensorReduceFunctorImpl<platform::float16, float, SquareSum>(
*input, &tmp, reduce_dims, dev_ctx.stream());
TensorReduceFunctorImpl<platform::float16, float, kps::AddFunctor,
kps::SquareFunctor<platform::float16, float>>(
*input, &tmp, kps::SquareFunctor<platform::float16, float>(),
reduce_dims, dev_ctx.stream());
auto tmp_eigen = EigenVector<float>::Flatten(tmp);
auto x_norm = tmp_eigen.sqrt();
@@ -15,7 +15,6 @@ limitations under the License. */
#include "paddle/fluid/framework/pten_utils.h"
#include "paddle/fluid/operators/elementwise/elementwise_add_op.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h"
#include "paddle/fluid/operators/reduce_ops/reduce_functor_op.h"
#include "paddle/fluid/operators/reduce_ops/reduce_op.cu.h"
#include "paddle/fluid/platform/complex.h"
#include "paddle/fluid/platform/float16.h"
@@ -91,7 +90,8 @@ default_elementwise_add_grad(const framework::ExecutionContext& ctx,
}
std::vector<int> reduce_dims = GetReduceDim(x->dims(), out->dims(), axis);
gpuStream_t stream = ctx.cuda_device_context().stream();
TensorReduceFunctorImpl<T, T, CustomSum>(*dout, dx, reduce_dims, stream);
TensorReduceFunctorImpl<T, T, kps::AddFunctor, kps::IdentityFunctor<T>>(
*dout, dx, kps::IdentityFunctor<T>(), reduce_dims, stream);
}
}
// dy
@@ -106,7 +106,8 @@ default_elementwise_add_grad(const framework::ExecutionContext& ctx,
} else {
std::vector<int> reduce_dims = GetReduceDim(y->dims(), out->dims(), axis);
gpuStream_t stream = ctx.cuda_device_context().stream();
TensorReduceFunctorImpl<T, T, CustomSum>(*dout, dy, reduce_dims, stream);
TensorReduceFunctorImpl<T, T, kps::AddFunctor, kps::IdentityFunctor<T>>(
*dout, dy, kps::IdentityFunctor<T>(), reduce_dims, stream);
}
}
}
@@ -14,7 +14,6 @@ limitations under the License. */
#include "paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h"
#include "paddle/fluid/operators/elementwise/elementwise_sub_op.h"
#include "paddle/fluid/operators/reduce_ops/reduce_functor_op.h"
#include "paddle/fluid/operators/reduce_ops/reduce_op.cu.h"
#include "paddle/fluid/platform/complex.h"
#include "paddle/fluid/platform/float16.h"
@@ -69,7 +68,8 @@ default_elementwise_sub_grad(const framework::ExecutionContext& ctx,
}
std::vector<int> reduce_dims = GetReduceDim(x->dims(), out->dims(), axis);
gpuStream_t stream = ctx.cuda_device_context().stream();
TensorReduceFunctorImpl<T, T, CustomSum>(*dout, dx, reduce_dims, stream);
TensorReduceFunctorImpl<T, T, kps::AddFunctor, kps::IdentityFunctor<T>>(
*dout, dx, kps::IdentityFunctor<T>(), reduce_dims, stream);
}
}
// dy
@@ -90,7 +90,8 @@ default_elementwise_sub_grad(const framework::ExecutionContext& ctx,
} else {
std::vector<int> reduce_dims = GetReduceDim(y->dims(), out->dims(), axis);
gpuStream_t stream = ctx.cuda_device_context().stream();
TensorReduceFunctorImpl<T, T, CustomSub>(*dout, dy, reduce_dims, stream);
TensorReduceFunctorImpl<T, T, kps::AddFunctor, kps::InverseFunctor<T>>(
*dout, dy, kps::InverseFunctor<T>(), reduce_dims, stream);
}
}
}
@@ -16,11 +16,10 @@ limitations under the License. */
#include "paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h"
#include "paddle/fluid/operators/kernel_primitives/kernel_primitives.h"
#include "paddle/fluid/operators/reduce_ops/reduce_functor_op.h"
#include "paddle/fluid/operators/reduce_ops/reduce_op.cu.h"
namespace paddle {
namespace operators {
// support gemm-nt and gemm-nn, which is used in fused_attention_op.
template <typename T>
class AttnMatMul {
@@ -165,8 +164,8 @@ class AttnMatMul {
(input_dims[2] == output_dims[0]));
if (support_case_1 || support_case_2) {
gpuStream_t stream = dev_ctx_.stream();
TensorReduceFunctorImpl<T, T, CustomSum>(*d_output, d_bias, {0, 1},
stream);
TensorReduceFunctorImpl<T, T, kps::AddFunctor, kps::IdentityFunctor<T>>(
*d_output, d_bias, kps::IdentityFunctor<T>(), {0, 1}, stream);
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"Only support reduce when the input dims are [0,1,2,3,4] and "
@@ -24,7 +24,6 @@ namespace cub = hipcub;
#include "paddle/fluid/operators/margin_cross_entropy_op.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/operators/math/softmax_impl.h"
#include "paddle/fluid/operators/reduce_ops/reduce_functor_op.h"
#include "paddle/fluid/operators/reduce_ops/reduce_op.h"
#include "paddle/fluid/string/string_helper.h"
@@ -128,17 +127,6 @@ __global__ void AddMarginToPositiveLogitsKernel(
}
}
template <typename Tx, typename Ty = Tx>
struct ExpAndSum {
using Transformer = kps::ExpFunctor<Tx>;
inline Ty initial() { return static_cast<Ty>(0.0f); }
__device__ __forceinline__ Ty operator()(const Ty& a, const Ty& b) const {
return b + a;
}
};
template <typename T>
__global__ void ScaleLogitKernel(T* logits, const float scale, const int64_t N,
const int64_t D) {
@@ -309,8 +297,9 @@ class MarginCrossEntropyOpCUDAKernel : public framework::OpKernel<T> {
logits_max =
ctx.AllocateTmpTensor<T, platform::CUDADeviceContext>({N, 1}, dev_ctx);
T* logits_max_buff = logits_max.mutable_data<T>(place);
TensorReduceFunctorImpl<T, T, CustomMax>(softmax_2d, &logits_max, {1},
dev_ctx.stream());
TensorReduceFunctorImpl<T, T, kps::MaxFunctor, kps::IdentityFunctor<T>>(
softmax_2d, &logits_max, kps::IdentityFunctor<T>(), {1},
dev_ctx.stream());
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
if (nranks > 1) {
@@ -330,8 +319,9 @@ class MarginCrossEntropyOpCUDAKernel : public framework::OpKernel<T> {
sum_exp_logits =
ctx.AllocateTmpTensor<T, platform::CUDADeviceContext>({N, 1}, dev_ctx);
T* sum_exp_logits_buff = sum_exp_logits.mutable_data<T>(place);
TensorReduceFunctorImpl<T, T, ExpAndSum>(softmax_2d, &sum_exp_logits, {1},
dev_ctx.stream());
TensorReduceFunctorImpl<T, T, kps::AddFunctor, kps::ExpFunctor<T>>(
softmax_2d, &sum_exp_logits, kps::ExpFunctor<T>(), {1},
dev_ctx.stream());
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
if (nranks > 1) {
@@ -59,28 +59,17 @@ __device__ __forceinline__ double inline_pow(double base, double exponent) {
return pow(base, exponent);
}
struct IdentityFunctor {
HOSTDEVICE explicit inline IdentityFunctor() {}
HOSTDEVICE explicit inline IdentityFunctor(int n) {}
template <typename T>
HOSTDEVICE inline T operator()(const T& x) const {
return static_cast<T>(x);
}
};
template <typename T>
struct NonzeroFunctor {
HOSTDEVICE explicit inline NonzeroFunctor() {}
HOSTDEVICE explicit inline NonzeroFunctor(int n) {}
template <typename T>
HOSTDEVICE inline T operator()(const T& x) const {
return static_cast<T>(static_cast<double>(x) != 0);
}
};
template <typename T>
struct AbsFunctor {
HOSTDEVICE explicit inline AbsFunctor() {}
HOSTDEVICE explicit inline AbsFunctor(int n) {}
template <typename T>
HOSTDEVICE inline T operator()(const T& x) const {
return static_cast<T>(inline_abs(x));
}
@@ -106,48 +95,6 @@ struct PowFunctor {
float porder;
};
template <typename Tx, typename Ty = Tx>
struct AbsAndMin {
using Transformer = AbsFunctor;
using MT = typename details::MPTypeTrait<Ty>::Type;
inline Ty initial() {
return static_cast<Ty>(std::numeric_limits<MT>::infinity());
}
__device__ __forceinline__ Ty operator()(const Ty& a, const Ty& b) const {
return (a < b) ? a : b;
}
};
template <typename Tx, typename Ty = Tx>
struct AbsAndMax {
using Transformer = AbsFunctor;
using MT = typename details::MPTypeTrait<Ty>::Type;
inline Ty initial() {
return static_cast<Ty>(-std::numeric_limits<MT>::infinity());
}
__device__ __forceinline__ Ty operator()(const Ty& a, const Ty& b) const {
return (a > b) ? a : b;
}
};
template <typename Tx, typename Ty = Tx>
struct NonzeroAndSum {
using Transformer = NonzeroFunctor;
inline Ty initial() { return static_cast<Ty>(0.0f); }
__device__ __forceinline__ Ty operator()(const Ty& a, const Ty& b) const {
return b + a;
}
};
template <typename Tx, typename Ty = Tx>
struct IdentityAndSum {
using Transformer = IdentityFunctor;
inline Ty initial() { return static_cast<Ty>(0.0f); }
__device__ __forceinline__ Ty operator()(const Ty& a, const Ty& b) const {
return b + a;
}
};
template <typename DeviceContext, typename T>
class PnormCUDAKernel : public framework::OpKernel<T> {
public:
@@ -167,14 +114,14 @@ class PnormCUDAKernel : public framework::OpKernel<T> {
using MT = typename details::MPTypeTrait<T>::Type;
if (porder == 0) {
TensorReduceFunctorImpl<T, T, NonzeroAndSum>(*in_x, out_norm, reduce_axis,
stream);
TensorReduceFunctorImpl<T, T, kps::AddFunctor, NonzeroFunctor<T>>(
*in_x, out_norm, NonzeroFunctor<T>(), reduce_axis, stream);
} else if (porder == INFINITY) {
TensorReduceFunctorImpl<T, T, AbsAndMax>(*in_x, out_norm, reduce_axis,
stream);
TensorReduceFunctorImpl<T, T, kps::MaxFunctor, AbsFunctor<T>>(
*in_x, out_norm, AbsFunctor<T>(), reduce_axis, stream);
} else if (porder == -INFINITY) {
TensorReduceFunctorImpl<T, T, AbsAndMin>(*in_x, out_norm, reduce_axis,
stream);
TensorReduceFunctorImpl<T, T, kps::MinFunctor, AbsFunctor<T>>(
*in_x, out_norm, AbsFunctor<T>(), reduce_axis, stream);
} else {
framework::Tensor tmp_x;
tmp_x.mutable_data<T>(xdim, ctx.GetPlace());
@@ -189,8 +136,8 @@ class PnormCUDAKernel : public framework::OpKernel<T> {
cuda_ctx, ins, &outs, func);
framework::Tensor tmp_y;
tmp_y.mutable_data<T>(ndim, ctx.GetPlace());
TensorReduceFunctorImpl<T, T, IdentityAndSum>(tmp_x, &tmp_y, reduce_axis,
stream);
TensorReduceFunctorImpl<T, T, kps::AddFunctor, kps::IdentityFunctor<T>>(
tmp_x, &tmp_y, kps::IdentityFunctor<T>(), reduce_axis, stream);
const framework::Tensor* tmp_norm = &tmp_y;
ins = {tmp_norm};
outs = {out_norm};
@@ -23,7 +23,6 @@ limitations under the License. */
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/operators/math/pooling.h"
#if defined(__HIPCC__) || defined(__NVCC__)
#include "paddle/fluid/operators/reduce_ops/reduce_functor_op.h"
#include "paddle/fluid/operators/reduce_ops/reduce_op.cu.h"
#endif
@@ -203,13 +202,14 @@ class PoolKernel : public framework::OpKernel<T> {
} else if (pooling_type == "avg") {
std::vector<int> reduce_dim;
int reduce_num = getReduceNum(*in_x, out, data_format, &reduce_dim);
if (reduce_num > 0 &&
adaptive) { // for adaptive_avg_pool2d && output_size == 1
#if defined(__HIPCC__) || defined(__NVCC__)
auto stream = dev_ctx.stream();
TensorReduceFunctorImpl<T, T, CustomMean>(*in_x, out, reduce_dim,
stream);
TensorReduceFunctorImpl<T, T, kps::AddFunctor,
kps::DivideFunctor<T>>(
*in_x, out, kps::DivideFunctor<T>(reduce_num), reduce_dim,
stream);
#else // for cpu
paddle::operators::math::Pool2dFunctor<
DeviceContext, paddle::operators::math::AvgPool<T>, T>
@@ -15,7 +15,7 @@ limitations under the License. */
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/math/prelu.h"
#include "paddle/fluid/operators/prelu_op.h"
#include "paddle/fluid/operators/reduce_ops/cub_reduce.h"
#include "paddle/fluid/operators/reduce_ops/reduce_op.cu.h"
#include "paddle/fluid/platform/device/gpu/gpu_primitives.h"
namespace paddle {
@@ -123,13 +123,6 @@ class PreluOpGradFunctor {
}
};
struct IdentityFunctor {
template <typename T>
HOSTDEVICE inline T operator()(const T& x) const {
return x;
}
};
template <typename DeviceContext, typename T>
class CUDAPReluGradKernel : public framework::OpKernel<T> {
public:
@@ -192,9 +185,8 @@ class CUDAPReluGradKernel : public framework::OpKernel<T> {
reduce_dims.push_back(i);
}
TensorReduce<T, T, cub::Sum, IdentityFunctor>(
dalpha_tmp, dalpha, reduce_dims, static_cast<T>(0), cub::Sum(),
IdentityFunctor(), stream);
TensorReduceFunctorImpl<T, T, kps::AddFunctor, kps::IdentityFunctor<T>>(
dalpha_tmp, dalpha, kps::IdentityFunctor<T>(), reduce_dims, stream);
}
};
@@ -13,8 +13,7 @@
// limitations under the License.
#include "paddle/fluid/operators/reduce_ops/reduce_all_op.h"
#include "paddle/fluid/operators/reduce_ops/reduce_functor_op.h"
REGISTER_OP_CUDA_KERNEL(
reduce_all,
ops::ReduceCudaKernel<bool, paddle::operators::CustomLogicalAnd>);
ops::ReduceCudaKernel<bool, kps::LogicalAndFunctor, kps::IdentityFunctor>);
@@ -13,9 +13,8 @@
// limitations under the License.
#include "paddle/fluid/operators/reduce_ops/reduce_any_op.h"
#include "paddle/fluid/operators/reduce_ops/reduce_functor_op.h"
#include "paddle/fluid/operators/reduce_ops/reduce_op.h"
REGISTER_OP_CUDA_KERNEL(
reduce_any,
ops::ReduceCudaKernel<bool, paddle::operators::CustomLogicalOr>);
ops::ReduceCudaKernel<bool, kps::LogicalOrFunctor, kps::IdentityFunctor>);
@@ -11,13 +11,13 @@
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/reduce_ops/reduce_functor_op.h"
#include "paddle/fluid/operators/reduce_ops/reduce_op.cu.h"
#include "paddle/fluid/operators/reduce_ops/reduce_op.h"
// reduce_max
REGISTER_OP_CUDA_KERNEL(
reduce_max, ops::ReduceCudaKernel<float, paddle::operators::CustomMax>,
ops::ReduceCudaKernel<double, paddle::operators::CustomMax>,
ops::ReduceCudaKernel<int, paddle::operators::CustomMax>,
ops::ReduceCudaKernel<int64_t, paddle::operators::CustomMax>);
reduce_max,
ops::ReduceCudaKernel<float, kps::MaxFunctor, kps::IdentityFunctor>,
ops::ReduceCudaKernel<double, kps::MaxFunctor, kps::IdentityFunctor>,
ops::ReduceCudaKernel<int, kps::MaxFunctor, kps::IdentityFunctor>,
ops::ReduceCudaKernel<int64_t, kps::MaxFunctor, kps::IdentityFunctor>);
@@ -13,11 +13,11 @@
// limitations under the License.
#include <vector>
#include "paddle/fluid/operators/reduce_ops/reduce_functor_op.h"
#include "paddle/fluid/operators/reduce_ops/reduce_mean_op.h"
#include "paddle/fluid/operators/reduce_ops/reduce_op.h"
REGISTER_OP_CUDA_KERNEL(
reduce_mean, ops::ReduceCudaKernel<bool, paddle::operators::CustomMean>,
ops::ReduceCudaKernel<float, paddle::operators::CustomMean>,
ops::ReduceCudaKernel<double, paddle::operators::CustomMean>);
reduce_mean,
ops::ReduceCudaKernel<bool, kps::AddFunctor, kps::DivideFunctor>,
ops::ReduceCudaKernel<float, kps::AddFunctor, kps::DivideFunctor>,
ops::ReduceCudaKernel<double, kps::AddFunctor, kps::DivideFunctor>);
@@ -11,13 +11,13 @@
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/reduce_ops/reduce_functor_op.h"
#include "paddle/fluid/operators/reduce_ops/reduce_op.cu.h"
#include "paddle/fluid/operators/reduce_ops/reduce_op.h"
// reduce_min
REGISTER_OP_CUDA_KERNEL(
reduce_min, ops::ReduceCudaKernel<float, paddle::operators::CustomMin>,
ops::ReduceCudaKernel<double, paddle::operators::CustomMin>,
ops::ReduceCudaKernel<int, paddle::operators::CustomMin>,
ops::ReduceCudaKernel<int64_t, paddle::operators::CustomMin>);
reduce_min,
ops::ReduceCudaKernel<float, kps::MinFunctor, kps::IdentityFunctor>,
ops::ReduceCudaKernel<double, kps::MinFunctor, kps::IdentityFunctor>,
ops::ReduceCudaKernel<int, kps::MinFunctor, kps::IdentityFunctor>,
ops::ReduceCudaKernel<int64_t, kps::MinFunctor, kps::IdentityFunctor>);
@@ -44,11 +44,11 @@ namespace cub = hipcub;
#define REDUCE_SPLIT_BOUNDARY 512
#define REDUCE_VEC_SIZE 4
namespace kps = paddle::operators::kernel_primitives;
namespace paddle {
namespace operators {
namespace kps = paddle::operators::kernel_primitives;
namespace details {
static inline int GetLastPow2(int n) {
@@ -722,12 +722,12 @@ __global__ void ReduceHigherDimKernel(const Tx* x, Ty* y, ReduceOp reducer,
}
}
template <typename Tx, typename Ty, typename MPType, typename ReduceOp>
template <typename Tx, typename Ty, typename MPType, typename ReduceOp,
typename TransformOp>
static void LaunchReduceKernel(const Tx* x_data, Ty* y_data,
const ReduceOp& reducer, MPType init,
const ReduceOp& reducer,
const TransformOp& transform, MPType init,
gpuStream_t stream, ReduceConfig<Ty> config) {
using TransformOp = typename ReduceOp::Transformer;
if (config.reduce_type == kReduceLastDim) {
int stride_reduce = 1;
int stride_left = config.reduce_num;
@@ -743,15 +743,15 @@ static void LaunchReduceKernel(const Tx* x_data, Ty* y_data,
#ifdef PADDLE_WITH_XPU2
ReduceAnyKernel<Tx, Ty, MPType, ReduceOp, TransformOp,
OneDimIndexCal><<<8, 128, stream>>>(
x_data, config.output_data, reducer, TransformOp(config.reduce_num),
init, config.reduce_num, config.left_num, config.reduce_last_dim,
reduce_index_calculator, left_index_calculator, dim);
x_data, config.output_data, reducer, transform, init, config.reduce_num,
config.left_num, config.reduce_last_dim, reduce_index_calculator,
left_index_calculator, dim);
#else
ReduceAnyKernel<Tx, Ty, MPType, ReduceOp, TransformOp,
OneDimIndexCal><<<config.grid, config.block, 0, stream>>>(
x_data, config.output_data, reducer, TransformOp(config.reduce_num),
init, config.reduce_num, config.left_num, config.reduce_last_dim,
reduce_index_calculator, left_index_calculator, dim);
x_data, config.output_data, reducer, transform, init, config.reduce_num,
config.left_num, config.reduce_last_dim, reduce_index_calculator,
left_index_calculator, dim);
#endif
} else {
@@ -771,15 +771,15 @@ static void LaunchReduceKernel(const Tx* x_data, Ty* y_data,
#ifdef PADDLE_WITH_XPU2
ReduceAnyKernel<Tx, Ty, MPType, ReduceOp, TransformOp,
IndexCalculator><<<8, 128, stream>>>(
x_data, config.output_data, reducer, TransformOp(config.reduce_num),
init, config.reduce_num, config.left_num, config.reduce_last_dim,
reduce_index_calculator, left_index_calculator, dim);
x_data, config.output_data, reducer, transform, init, config.reduce_num,
config.left_num, config.reduce_last_dim, reduce_index_calculator,
left_index_calculator, dim);
#else
ReduceAnyKernel<Tx, Ty, MPType, ReduceOp, TransformOp,
IndexCalculator><<<config.grid, config.block, 0, stream>>>(
x_data, config.output_data, reducer, TransformOp(config.reduce_num),
init, config.reduce_num, config.left_num, config.reduce_last_dim,
reduce_index_calculator, left_index_calculator, dim);
x_data, config.output_data, reducer, transform, init, config.reduce_num,
config.left_num, config.reduce_last_dim, reduce_index_calculator,
left_index_calculator, dim);
#endif
}
@@ -802,23 +802,22 @@ static void LaunchReduceKernel(const Tx* x_data, Ty* y_data,
#ifdef PADDLE_WITH_XPU2
ReduceHigherDimKernel<Ty, Ty, MPType, ReduceOp,
kps::IdentityFunctor<Ty, MPType>><<<8, 128, stream>>>(
config.output_data, y_data, reducer,
kps::IdentityFunctor<Ty, MPType>(config.grid.y), init, config.grid.y,
config.left_num, config.grid.y, dim);
config.output_data, y_data, reducer, kps::IdentityFunctor<Ty, MPType>(),
init, config.grid.y, config.left_num, config.grid.y, dim);
#else
ReduceHigherDimKernel<
Ty, Ty, MPType, ReduceOp,
kps::IdentityFunctor<Ty, MPType>><<<grid, block, 0, stream>>>(
config.output_data, y_data, reducer,
kps::IdentityFunctor<Ty, MPType>(config.grid.y), init, config.grid.y,
config.left_num, config.grid.y, dim);
config.output_data, y_data, reducer, kps::IdentityFunctor<Ty, MPType>(),
init, config.grid.y, config.left_num, config.grid.y, dim);
#endif
}
}
template <typename Tx, typename Ty,
template <typename, typename> class ReduceOp>
template <typename Tx, typename Ty, template <typename> class ReduceOp,
typename TransformOp>
void TensorReduceFunctorImpl(const framework::Tensor& x, framework::Tensor* y,
const TransformOp& transform,
std::vector<int> origin_reduce_dims,
gpuStream_t stream) {
auto x_dim = framework::vectorize<int>(x.dims());
@@ -853,10 +852,9 @@ void TensorReduceFunctorImpl(const framework::Tensor& x, framework::Tensor* y,
(!std::is_same<Tx, paddle::platform::float16>::value);
if (use_cub_reduce) {
// launch CUB::Reduce
using TransformOp = typename ReduceOp<Tx, Ty>::Transformer;
auto reducer = ReduceOp<Tx, Ty>();
cub::TransformInputIterator<Ty, TransformOp, const Tx*> trans_x(
x_data, TransformOp(config.reduce_num));
auto reducer = ReduceOp<Ty>();
cub::TransformInputIterator<Ty, TransformOp, const Tx*> trans_x(x_data,
transform);
size_t temp_storage_bytes = 0;
cub::DeviceReduce::Reduce(nullptr, temp_storage_bytes, trans_x, y_data,
config.reduce_num, reducer, reducer.initial(),
@@ -873,7 +871,7 @@ void TensorReduceFunctorImpl(const framework::Tensor& x, framework::Tensor* y,
}
using MPType = typename details::MPTypeTrait<Ty>::Type;
auto reducer = ReduceOp<Tx, MPType>();
auto reducer = ReduceOp<MPType>();
// launch ReduceHigherDimKernel
// when reduce_dim.size() == 1 and reduce_dim[0] != x_dim.size() - 1, this
// function will be used
@@ -882,7 +880,6 @@ void TensorReduceFunctorImpl(const framework::Tensor& x, framework::Tensor* y,
// 32
// else grid.z = 1, grid.y = ny / block_size, grid.x = nx /32
if (config.reduce_type == ReduceType::kReduceHigherDim) {
using TransformOp = typename ReduceOp<Tx, MPType>::Transformer;
kps::DimConfig dim =
kps::DimConfig(config.grid.x, config.grid.y, config.grid.z,
config.block.x, config.blocking_size, 0);
@@ -890,18 +887,16 @@ void TensorReduceFunctorImpl(const framework::Tensor& x, framework::Tensor* y,
config.reduce_num % config.blocking_size, 0);
#ifdef PADDLE_WITH_XPU2
ReduceHigherDimKernel<Tx, Ty, MPType, ReduceOp<Tx, MPType>,
ReduceHigherDimKernel<Tx, Ty, MPType, ReduceOp<MPType>,
TransformOp><<<8, 128, stream>>>(
x_data, config.output_data, reducer, TransformOp(config.reduce_num),
reducer.initial(), config.reduce_num, config.left_num,
config.blocking_size, dim);
x_data, config.output_data, reducer, transform, reducer.initial(),
config.reduce_num, config.left_num, config.blocking_size, dim);
#else
ReduceHigherDimKernel<
Tx, Ty, MPType, ReduceOp<Tx, MPType>,
Tx, Ty, MPType, ReduceOp<MPType>,
TransformOp><<<config.grid, config.block, 0, stream>>>(
x_data, config.output_data, reducer, TransformOp(config.reduce_num),
reducer.initial(), config.reduce_num, config.left_num,
config.blocking_size, dim);
x_data, config.output_data, reducer, transform, reducer.initial(),
config.reduce_num, config.left_num, config.blocking_size, dim);
#endif
if (config.should_reduce_again) {
@@ -913,14 +908,14 @@ void TensorReduceFunctorImpl(const framework::Tensor& x, framework::Tensor* y,
#ifdef PADDLE_WITH_XPU2
ReduceHigherDimKernel<
Ty, Ty, MPType, ReduceOp<Tx, MPType>,
Ty, Ty, MPType, ReduceOp<MPType>,
kps::IdentityFunctor<Ty, MPType>><<<8, 128, stream>>>(
config.output_data, y_data, reducer,
kps::IdentityFunctor<Ty, MPType>(config.grid.y), reducer.initial(),
config.grid.y, config.left_num, config.grid.y, dim2);
#else
ReduceHigherDimKernel<
Ty, Ty, MPType, ReduceOp<Tx, MPType>,
Ty, Ty, MPType, ReduceOp<MPType>,
kps::IdentityFunctor<Ty, MPType>><<<grid, block, 0, stream>>>(
config.output_data, y_data, reducer,
kps::IdentityFunctor<Ty, MPType>(config.grid.y), reducer.initial(),
......@@ -933,23 +928,32 @@ void TensorReduceFunctorImpl(const framework::Tensor& x, framework::Tensor* y,
// when reduce_dim.size() == 1 and reduce_dim[0] == x_dim.size() - 1, or
// when reduce_dim.size() != 1 and reduce_dim.size() != x_dim.size(), this
// function will be used
LaunchReduceKernel<Tx, Ty, MPType, ReduceOp<Tx, MPType>>(
x_data, y_data, reducer, reducer.initial(), stream, config);
LaunchReduceKernel<Tx, Ty, MPType, ReduceOp<MPType>, TransformOp>(
x_data, y_data, reducer, transform, reducer.initial(), stream, config);
}
template <typename Tx, template <typename, typename> class ReduceOp>
template <typename Tx, template <typename> class ReduceOp,
template <typename, typename> class TransformOp>
struct TensorReduceFunc {
const framework::Tensor& x;
framework::Tensor* y;
std::vector<int> origin_reduce_dims;
gpuStream_t stream;
int reduce_num;
TensorReduceFunc(const framework::Tensor& x, framework::Tensor* y,
std::vector<int> origin_reduce_dims, gpuStream_t stream)
: x(x), y(y), origin_reduce_dims(origin_reduce_dims), stream(stream) {}
std::vector<int> origin_reduce_dims, int num_reduce,
gpuStream_t stream)
: x(x),
y(y),
origin_reduce_dims(origin_reduce_dims),
reduce_num(num_reduce),
stream(stream) {}
template <typename Ty>
void apply() const {
TensorReduceFunctorImpl<Tx, Ty, ReduceOp>(x, y, origin_reduce_dims, stream);
using MPType = typename details::MPTypeTrait<Ty>::Type;
TensorReduceFunctorImpl<Tx, Ty, ReduceOp, TransformOp<Tx, MPType>>(
x, y, TransformOp<Tx, MPType>(reduce_num), origin_reduce_dims, stream);
}
};
@@ -670,7 +670,8 @@ If reduce_all is true, just reduce along all dimensions and output a scalar.
};
#if defined(__HIPCC__) || defined(__NVCC__)
template <typename T, template <typename, typename> class ReduceOp>
template <typename T, template <typename> class ReduceOp,
template <typename, typename> class TransformOp>
class ReduceCudaKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
@@ -682,15 +683,19 @@ class ReduceCudaKernel : public framework::OpKernel<T> {
std::vector<int> reduce_dims =
GetReduceDim(dims, input->dims().size(), reduce_all);
int reduce_num = 1;
for (int i = 0; i < input->dims().size(); i++) {
reduce_num *= (input->dims())[i];
}
gpuStream_t stream = context.cuda_device_context().stream();
if (out_dtype >= 0) {
framework::VisitDataTypeSmall(
static_cast<framework::proto::VarType::Type>(out_dtype),
TensorReduceFunc<T, ReduceOp>(*input, output, reduce_dims, stream));
TensorReduceFunc<T, ReduceOp, TransformOp>(
*input, output, reduce_dims, reduce_num, stream));
} else {
TensorReduceFunctorImpl<T, T, ReduceOp>(*input, output, reduce_dims,
stream);
TensorReduceFunctorImpl<T, T, ReduceOp, TransformOp<T, T>>(
*input, output, TransformOp<T, T>(reduce_num), reduce_dims, stream);
}
}
};
@@ -12,12 +12,12 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/reduce_ops/reduce_functor_op.h"
#include "paddle/fluid/operators/reduce_ops/reduce_op.cu.h"
#include "paddle/fluid/operators/reduce_ops/reduce_prod_op.h"
REGISTER_OP_CUDA_KERNEL(
reduce_prod, ops::ReduceCudaKernel<float, paddle::operators::CustomMul>,
ops::ReduceCudaKernel<int, paddle::operators::CustomMul>,
ops::ReduceCudaKernel<double, paddle::operators::CustomMul>,
ops::ReduceCudaKernel<int64_t, paddle::operators::CustomMul>);
reduce_prod,
ops::ReduceCudaKernel<float, kps::MulFunctor, kps::IdentityFunctor>,
ops::ReduceCudaKernel<int, kps::MulFunctor, kps::IdentityFunctor>,
ops::ReduceCudaKernel<double, kps::MulFunctor, kps::IdentityFunctor>,
ops::ReduceCudaKernel<int64_t, kps::MulFunctor, kps::IdentityFunctor>);
@@ -11,18 +11,18 @@
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/reduce_ops/reduce_functor_op.h"
#include "paddle/fluid/operators/reduce_ops/reduce_op.h"
#include "paddle/fluid/operators/reduce_ops/reduce_sum_op.h"
REGISTER_OP_CUDA_KERNEL(
reduce_sum, ops::ReduceCudaKernel<bool, paddle::operators::CustomSum>,
ops::ReduceCudaKernel<float, paddle::operators::CustomSum>,
ops::ReduceCudaKernel<double, paddle::operators::CustomSum>,
ops::ReduceCudaKernel<paddle::platform::float16,
paddle::operators::CustomSum>,
ops::ReduceCudaKernel<int, paddle::operators::CustomSum>,
ops::ReduceCudaKernel<int64_t, paddle::operators::CustomSum>,
ops::ReduceCudaKernel<paddle::platform::complex<float>,
paddle::operators::CustomSum>,
ops::ReduceCudaKernel<paddle::platform::complex<double>,
paddle::operators::CustomSum>);
reduce_sum,
ops::ReduceCudaKernel<bool, kps::AddFunctor, kps::IdentityFunctor>,
ops::ReduceCudaKernel<float, kps::AddFunctor, kps::IdentityFunctor>,
ops::ReduceCudaKernel<double, kps::AddFunctor, kps::IdentityFunctor>,
ops::ReduceCudaKernel<paddle::platform::float16, kps::AddFunctor,
kps::IdentityFunctor>,
ops::ReduceCudaKernel<int, kps::AddFunctor, kps::IdentityFunctor>,
ops::ReduceCudaKernel<int64_t, kps::AddFunctor, kps::IdentityFunctor>,
ops::ReduceCudaKernel<paddle::platform::complex<float>, kps::AddFunctor,
kps::IdentityFunctor>,
ops::ReduceCudaKernel<paddle::platform::complex<double>, kps::AddFunctor,
kps::IdentityFunctor>);
@@ -15,21 +15,12 @@
#include <thrust/device_vector.h>
#include <thrust/host_vector.h>
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/operators/reduce_ops/cub_reduce.h"
#include "paddle/fluid/operators/reduce_ops/reduce_op.cu.h"
#include "paddle/fluid/operators/trace_op.h"
namespace paddle {
namespace operators {
struct IdentityFunctor {
HOSTDEVICE explicit inline IdentityFunctor() {}
template <typename U>
HOSTDEVICE inline U operator()(const U& x) const {
return x;
}
};
template <typename DeviceContext, typename T>
class TraceCUDAKernel : public framework::OpKernel<T> {
public:
@@ -48,9 +39,8 @@ class TraceCUDAKernel : public framework::OpKernel<T> {
auto stream = context.cuda_device_context().stream();
std::vector<int> reduce_dims;
reduce_dims.push_back(out->dims().size());
TensorReduce<T, T, cub::Sum, IdentityFunctor>(
diag, out, reduce_dims, static_cast<T>(0), cub::Sum(),
IdentityFunctor(), stream);
TensorReduceFunctorImpl<T, T, kps::AddFunctor, kps::IdentityFunctor<T>>(
diag, out, kps::IdentityFunctor<T>(), reduce_dims, stream);
} else {
math::SetConstant<DeviceContext, T> functor;
functor(context.device_context<DeviceContext>(), out, static_cast<T>(0));
@@ -12,7 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/reduce_ops/reduce_functor_op.h"
#include "paddle/fluid/operators/reduce_ops/reduce_op.h"
#include "paddle/fluid/operators/triangular_solve_op.h"
@@ -44,7 +43,8 @@ struct MatrixReduceSumFunctor<platform::CUDADeviceContext, T> {
}
}
gpuStream_t stream = ctx.cuda_device_context().stream();
TensorReduceFunctorImpl<T, T, CustomSum>(in, out, out_reduce_dims, stream);
TensorReduceFunctorImpl<T, T, kps::AddFunctor, kps::IdentityFunctor<T>>(
in, out, kps::IdentityFunctor<T>(), out_reduce_dims, stream);
}
};