Unverified · Commit f9e9fd19 · Author: chentianyu03 · Committed by: GitHub

[Pten] Add reduce mean kernel, replace with mean API (#37559)

* add pten reduce kernel

* add reduce_sum kernel

* update attribute args and order

* make out dtype undefined

* fix empty input error

* merge develop branch

* rename sum as reduce function

* rename sum as reduce function

* fix ReduceKernelImpl args error

* add reduce cuda kernel

* modify dims type to const &

* remove unused log

* fix reduce_all out eigen function error

* remove unused codes

* add the missing sum api define and testcase

* merge develop branch

* fix sum test axis value error

* replace pten mean kernel with reduce_mean

* recover mean cuda kernel to the original implementation
Parent: dae4e7f2
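For reference, the new C++ API declared in this diff replaces the old mean(x) signature with explicit axis and keep_dim arguments (sum additionally takes a dtype). A minimal call-site sketch, not part of the commit itself and assuming an already constructed paddle::experimental::Tensor x:

// Illustration only; the Tensor x is assumed to be constructed elsewhere.
auto m = paddle::experimental::mean(x, /*axis=*/{0, 1}, /*keep_dim=*/false);
auto s = paddle::experimental::sum(x, /*axis=*/{1},
                                   paddle::experimental::DataType::UNDEFINED,
                                   /*keep_dim=*/true);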
......@@ -25,6 +25,17 @@ namespace cub = hipcub;
namespace paddle {
namespace operators {
template <typename T>
struct DivideFunctor {
HOSTDEVICE explicit inline DivideFunctor(int n)
: n_inv(static_cast<T>(1.0 / n)) {}
HOSTDEVICE inline T operator()(const T& x) const { return x * n_inv; }
private:
T n_inv;
};
template <typename T>
__global__ void MeanRunKernel(const T* in_data, T* out_data, int N) {
int idx = blockDim.x * blockIdx.x + threadIdx.x;
......@@ -34,6 +45,37 @@ __global__ void MeanRunKernel(const T* in_data, T* out_data, int N) {
}
}
template <typename DeviceContext, typename T>
class MeanCUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* input = context.Input<Tensor>("X");
auto* output = context.Output<Tensor>("Out");
output->mutable_data<T>(context.GetPlace());
auto size_prob = input->numel();
const T* in_data = input->data<T>();
T* out_data = output->mutable_data<T>(context.GetPlace());
auto stream = context.cuda_device_context().stream();
DivideFunctor<T> transformer(size_prob);
cub::TransformInputIterator<T, DivideFunctor<T>, const T*> trans_x(
in_data, transformer);
size_t temp_storage_bytes = 0;
auto err = cub::DeviceReduce::Sum(nullptr, temp_storage_bytes, trans_x,
out_data, size_prob, stream);
PADDLE_ENFORCE_CUDA_SUCCESS(err);
framework::Tensor tmp;
auto* temp_storage = tmp.mutable_data<uint8_t>(
framework::make_ddim({static_cast<int64_t>(temp_storage_bytes)}),
context.GetPlace());
err = cub::DeviceReduce::Sum(temp_storage, temp_storage_bytes, trans_x,
out_data, size_prob, stream);
PADDLE_ENFORCE_CUDA_SUCCESS(err);
}
};
template <typename DeviceContext, typename T>
class MeanCUDAGradKernel : public framework::OpKernel<T> {
public:
......@@ -62,11 +104,10 @@ class MeanCUDAGradKernel : public framework::OpKernel<T> {
namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OP_CUDA_KERNEL(
mean, ops::MeanKernel<paddle::platform::CUDADeviceContext, float>,
ops::MeanKernel<paddle::platform::CUDADeviceContext, double>,
ops::MeanKernel<paddle::platform::CUDADeviceContext, plat::float16>);
mean, ops::MeanCUDAKernel<paddle::platform::CUDADeviceContext, float>,
ops::MeanCUDAKernel<paddle::platform::CUDADeviceContext, double>,
ops::MeanCUDAKernel<paddle::platform::CUDADeviceContext, plat::float16>);
REGISTER_OP_CUDA_KERNEL(
mean_grad,
ops::MeanCUDAGradKernel<paddle::platform::CUDADeviceContext, float>,
......
......@@ -15,12 +15,6 @@ limitations under the License. */
#pragma once
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/pten_utils.h"
// Only the headers under the paddle/top/api dirs can be included
#include "paddle/pten/api/lib/utils/tensor_utils.h"
#include "paddle/pten/include/core.h"
#include "paddle/pten/include/math.h"
namespace paddle {
namespace operators {
......@@ -33,40 +27,21 @@ template <typename T, int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex>
using EigenVector = framework::EigenVector<T, MajorType, IndexType>;
/** [ Why still keep the original kernel implementation? ]
*
* Removal of the original kernel implementation and kernel registration needs
* to ensure that the new kernel mechanism adapts to multiple sets of execution
* mechanisms, including:
*
* 1. Executor and ParallelExecutor
* 2. Dygraph OpBase (Tracer and Engine)
* 3. New Executor
* 4. Predictor
* 5. NPU and XPU lack kernels and need to reuse the CPU Kernel
*
* Removal of the original Kernel requires a more complete solution to ensure
* that it will not affect the current execution system.
* Currently, only the first two cases are adapted.
*
* The principle here is that the implementation in the kernel must reuse the
* corresponding functions in the Tensor Operation library and cannot maintain
* two copies of the code.
*/
template <typename DeviceContext, typename T>
class MeanKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* x = context.Input<Tensor>("X");
auto* out = context.Output<Tensor>("Out");
auto& dev_ctx = context.device_context<DeviceContext>();
out->mutable_data<T>(x->place());
auto* input = context.Input<Tensor>("X");
auto* output = context.Output<Tensor>("Out");
output->mutable_data<T>(context.GetPlace());
auto pt_x = paddle::experimental::MakePtenDenseTensor(*x);
auto pt_out = paddle::experimental::MakePtenDenseTensor(*out);
auto X = EigenVector<T>::Flatten(*input);
auto y = EigenScalar<T>::From(*output);
auto& place =
*context.template device_context<DeviceContext>().eigen_device();
// call new kernel
pten::Mean<T>(dev_ctx, *pt_x.get(), pt_out.get());
y.device(place) = X.mean();
}
};
......
......@@ -23,6 +23,13 @@ limitations under the License. */
#include "paddle/fluid/operators/cast_op.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/operators/reduce_ops/reduce_op_function.h"
// Only the headers under the paddle/pten/api dirs can be included
#include "paddle/pten/api/lib/utils/tensor_utils.h"
#include "paddle/pten/include/core.h"
#include "paddle/pten/include/math.h"
#include "paddle/pten/kernels/functions/general/reduce_impl.h"
#if defined(__HIPCC__) || defined(__NVCC__)
#include "paddle/fluid/operators/reduce_ops/reduce_op.cu.h"
#endif
......@@ -232,43 +239,29 @@ class ReduceKernel : public framework::OpKernel<T> {
bool keep_dim = context.Attr<bool>("keep_dim");
int out_dtype = context.Attr<int>("out_dtype");
framework::proto::VarType::Type cast_out_dtype;
// If dims covers all of the input's dimensions, set reduce_all to true
const auto& input_dim_size = context.Input<Tensor>("X")->dims().size();
std::set<int> dims_set(dims.begin(), dims.end());
bool full_dim = true;
for (auto i = 0; i < input_dim_size; i++) {
if (dims_set.find(i) == dims_set.end()) {
full_dim = false;
break;
}
}
reduce_all = (reduce_all || full_dim);
auto* input = context.Input<Tensor>("X");
if (out_dtype < 0) {
auto* cast_input = context.Input<Tensor>("X");
cast_out_dtype =
static_cast<framework::proto::VarType::Type>(cast_input->type());
framework::VisitDataType(
cast_out_dtype,
ReduceKernelFunctor<DeviceContext, T, Functor>(
cast_input, output, dims, keep_dim, reduce_all, context));
static_cast<framework::proto::VarType::Type>(input->type());
} else {
Tensor tmp_tensor;
cast_out_dtype = static_cast<framework::proto::VarType::Type>(out_dtype);
auto* input = context.Input<Tensor>("X");
tmp_tensor.Resize(input->dims());
framework::VisitDataType(
cast_out_dtype,
CastOpFunctor<DeviceContext, T>(
input, &tmp_tensor,
context.template device_context<DeviceContext>()));
framework::VisitDataType(
cast_out_dtype,
ReduceKernelFunctor<DeviceContext, T, Functor>(
&tmp_tensor, output, dims, keep_dim, reduce_all, context));
}
auto& dev_ctx = context.device_context<DeviceContext>();
output->mutable_data(
dev_ctx.GetPlace(),
static_cast<framework::proto::VarType::Type>(cast_out_dtype));
auto pt_x = paddle::experimental::MakePtenDenseTensor(*input);
auto pt_out = paddle::experimental::MakePtenDenseTensor(*output);
std::vector<int64_t> tmp_dims(dims.begin(), dims.end());
// call new kernel
pten::general::Reduce<DeviceContext, T, Functor>(
dev_ctx, *pt_x.get(), reduce_all, tmp_dims, keep_dim,
pten::TransToPtenDataType(cast_out_dtype), pt_out.get());
}
};
template <typename DeviceContext, typename OutT, typename Functor>
......
......@@ -21,7 +21,9 @@ namespace experimental {
// TODO(chenweihang): add scale API
// TODO(chenweihang): move mean API into stat.h/cc
PD_DLL_DECL Tensor mean(const Tensor& x);
PD_DLL_DECL Tensor mean(const Tensor& x,
const std::vector<int64_t>& axis,
bool keep_dim);
PD_DLL_DECL Tensor add(const Tensor& x, const Tensor& y);
......@@ -31,5 +33,10 @@ PD_DLL_DECL Tensor divide(const Tensor& x, const Tensor& y);
PD_DLL_DECL Tensor multiply(const Tensor& x, const Tensor& y);
PD_DLL_DECL Tensor sum(const Tensor& x,
const std::vector<int64_t>& axis,
DataType dtype,
bool keep_dim);
} // namespace experimental
} // namespace paddle
......@@ -35,12 +35,14 @@ PT_DECLARE_MODULE(MathCUDA);
namespace paddle {
namespace experimental {
PD_DLL_DECL Tensor mean(const Tensor& x) {
PD_DLL_DECL Tensor mean(const Tensor& x,
const std::vector<int64_t>& axis,
bool keep_dim) {
// 1. Get kernel signature and kernel
auto kernel_key_set = ParseKernelKeyByInputArgs(x);
auto kernel_key = kernel_key_set.GetHigestPriorityKernelKey();
auto kernel = pten::KernelFactory::Instance().SelectKernelOrThrowError(
"mean", kernel_key);
"reduce_mean", kernel_key);
// 2. Get Device Context
auto* dev_ctx = GetDeviceContextByBackend(kernel_key.backend());
......@@ -50,8 +52,73 @@ PD_DLL_DECL Tensor mean(const Tensor& x) {
auto dense_x = std::dynamic_pointer_cast<pten::DenseTensor>(x.impl());
kernel_context.EmplaceBackInput(dense_x);
// The real value of reduce_all will be computed in the kernel,
// so using the default value (false) here is OK.
bool reduce_all = false;
DataType out_dtype = DataType::UNDEFINED;
kernel_context.EmplaceBackAttr(axis);
kernel_context.EmplaceBackAttr(keep_dim);
kernel_context.EmplaceBackAttr(reduce_all);
kernel_context.EmplaceBackAttr(dense_x->dtype());
kernel_context.EmplaceBackAttr(out_dtype);
// 4. InferShape
auto out_meta = ReduceInferMeta(dense_x->meta(), axis, keep_dim);
// 5. Prepare outputs
Tensor out;
const auto allocator =
std::make_shared<paddle::experimental::DefaultAllocator>(
pten::TransToFluidPlace(kernel_key.backend()));
auto dense_out = std::make_shared<pten::DenseTensor>(allocator, out_meta);
kernel_context.EmplaceBackOutput(dense_out);
out.set_impl(dense_out);
// 6. Call kernel
kernel(&kernel_context);
return out;
}
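// Note (added for clarity, not in the original diff): the EmplaceBackAttr order
// above (axis, keep_dim, reduce_all, in_dtype, out_dtype) mirrors the attribute
// order of the pten reduce_mean kernel signature registered in cpu/cuda math.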
PD_DLL_DECL Tensor sum(const Tensor& x,
const std::vector<int64_t>& axis,
DataType dtype,
bool keep_dim) {
// 1. Get kernel signature and kernel
auto kernel_key_set = ParseKernelKeyByInputArgs(x);
auto kernel_key = kernel_key_set.GetHigestPriorityKernelKey();
auto kernel = pten::KernelFactory::Instance().SelectKernelOrThrowError(
"reduce_sum", kernel_key);
// 2. Get Device Context
auto* dev_ctx = GetDeviceContextByBackend(kernel_key.backend());
auto kernel_context = pten::KernelContext(dev_ctx);
// 3. Auto data transform
auto dense_x = std::dynamic_pointer_cast<pten::DenseTensor>(x.impl());
kernel_context.EmplaceBackInput(dense_x);
// The real value of reduce_all will be computed in the kernel,
// so using the default value (false) here is OK.
bool reduce_all = false;
DataType out_dtype = DataType::UNDEFINED;
if (dense_x->dtype() == DataType::BOOL ||
dense_x->dtype() == DataType::INT32 ||
dense_x->dtype() == DataType::INT64) {
out_dtype = DataType::INT64;
}
kernel_context.EmplaceBackAttr(axis);
kernel_context.EmplaceBackAttr(keep_dim);
kernel_context.EmplaceBackAttr(reduce_all);
kernel_context.EmplaceBackAttr(dense_x->dtype());
kernel_context.EmplaceBackAttr(out_dtype);
// 4. InferMeta
auto out_meta = ReductionInferMeta(dense_x->meta());
auto out_meta = ReduceInferMeta(dense_x->meta(), axis, keep_dim);
// 5. Prepare outputs
Tensor out;
......
......@@ -34,13 +34,44 @@ DenseTensor Sign(const ContextT& dev_ctx, const DenseTensor& x) {
}
template <typename T, typename ContextT>
DenseTensor Mean(const ContextT& dev_ctx, const DenseTensor& x) {
auto out_meta = ReductionInferMeta(x.meta());
DenseTensor Mean(const ContextT& dev_ctx,
const DenseTensor& x,
const std::vector<int64_t>& axis,
bool keep_dim) {
auto out_meta = ReduceInferMeta(x.meta(), axis, keep_dim);
const auto allocator =
std::make_shared<paddle::experimental::DefaultAllocator>(
dev_ctx.GetPlace());
pten::DenseTensor dense_out(allocator, out_meta);
Mean<T>(dev_ctx, x, &dense_out);
bool reduce_all = false;
DataType out_dtype = pten::DataType::UNDEFINED;
Mean<T>(
dev_ctx, x, axis, keep_dim, reduce_all, x.dtype(), out_dtype, &dense_out);
return dense_out;
}
template <typename T, typename ContextT>
DenseTensor Sum(const ContextT& dev_ctx,
const DenseTensor& x,
const std::vector<int64_t>& axis,
DataType dtype,
bool keep_dim) {
auto out_meta = ReduceInferMeta(x.meta(), axis, keep_dim);
const auto allocator =
std::make_shared<paddle::experimental::DefaultAllocator>(
dev_ctx.GetPlace());
pten::DenseTensor dense_out(allocator, out_meta);
// The real value of reduce_all will be computed in the kernel,
// so using the default value (false) here is OK.
bool reduce_all = false;
if (x.dtype() == pten::DataType::BOOL || x.dtype() == pten::DataType::INT32 ||
x.dtype() == pten::DataType::INT64) {
dtype = pten::DataType::INT64;
}
Sum<T>(dev_ctx, x, axis, keep_dim, reduce_all, x.dtype(), dtype, &dense_out);
return dense_out;
}
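// Hypothetical usage of the helpers above (illustration only, not part of the
// commit), given a device context dev_ctx and a DenseTensor x:
//   auto m = pten::Mean<float>(dev_ctx, x, /*axis=*/{0}, /*keep_dim=*/false);
//   auto s = pten::Sum<float>(dev_ctx, x, /*axis=*/{1},
//                             pten::DataType::UNDEFINED, /*keep_dim=*/true);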
......
......@@ -14,6 +14,7 @@ limitations under the License. */
// See Note [ Why still include the fluid headers? ]
#include "paddle/pten/infermeta/unary.h"
#include <set>
namespace pten {
......@@ -226,4 +227,50 @@ DenseTensorMeta InferMetaFromVecValue(const DenseTensorMeta& x_meta,
return return_meta;
}
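// Illustrative example (added comment, not in the original source): for
// x_meta.dims = [2, 3, 4] and axis = {1}, reduce_all stays false; keep_dim ==
// true yields output dims [2, 1, 4], keep_dim == false yields [2, 4]. For
// BOOL/INT32/INT64 inputs the output dtype is promoted to INT64, otherwise it
// matches the input dtype.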
DenseTensorMeta ReduceInferMeta(const DenseTensorMeta& x_meta,
const std::vector<int64_t>& axis,
bool keep_dim) {
bool reduce_all = true;
std::set<int64_t> dims_set(axis.begin(), axis.end());
for (int64_t i = 0; i < x_meta.dims.size(); ++i) {
if (dims_set.find(i) == dims_set.end()) {
reduce_all = false;
break;
}
}
std::vector<int64_t> out_dim_vector;
if (keep_dim) {
for (int64_t i = 0; i < x_meta.dims.size(); ++i) {
if (reduce_all || dims_set.find(i) != dims_set.end()) {
out_dim_vector.push_back(1);
} else {
out_dim_vector.push_back(x_meta.dims.at(i));
}
}
} else {
for (int64_t i = 0; i < x_meta.dims.size(); ++i) {
if (reduce_all || dims_set.find(i) != dims_set.end()) {
continue;
} else {
out_dim_vector.push_back(x_meta.dims.at(i));
}
}
if (out_dim_vector.size() == 0) {
out_dim_vector.push_back(1);
}
}
DDim out_dim = paddle::framework::make_ddim(out_dim_vector);
DataType out_dtype = x_meta.dtype;
if (x_meta.dtype == DataType::BOOL || x_meta.dtype == DataType::INT32 ||
x_meta.dtype == DataType::INT64) {
out_dtype = DataType::INT64;
}
DenseTensorMeta return_meta(out_dtype, out_dim, x_meta.layout);
return return_meta;
}
} // namespace pten
......@@ -49,4 +49,8 @@ DenseTensorMeta FullLikeInferMeta(const DenseTensorMeta& x_meta,
DenseTensorMeta InferMetaFromVecValue(const DenseTensorMeta& x_meta,
const std::vector<int64_t>& shape);
DenseTensorMeta ReduceInferMeta(const DenseTensorMeta& x_meta,
const std::vector<int64_t>& axis,
bool keep_dim);
} // namespace pten
# pten math functions called by kernels
add_subdirectory(math)
# pten basic functions called by kernels
add_subdirectory(functions)
# pten kernels for diff device
......
cc_library(math_cpu SRCS math.cc DEPS dense_tensor kernel_context kernel_factory eigen_function blas)
cc_library(math_cpu SRCS math.cc DEPS dense_tensor kernel_context kernel_factory eigen_function blas pten_transpose_cpu)
cc_library(linalg_cpu SRCS linalg.cc DEPS dense_tensor kernel_context kernel_factory)
cc_library(creation_cpu SRCS creation.cc DEPS dense_tensor kernel_context kernel_factory eigen_function)
cc_library(utils_cpu SRCS utils.cc DEPS dense_tensor kernel_context kernel_factory memory convert_utils)
......
......@@ -14,11 +14,13 @@
#include "paddle/pten/kernels/cpu/math.h"
#include "paddle/pten/api/ext/dispatch.h"
#include "paddle/pten/kernels/functions/cpu/elementwise.h"
#include "paddle/pten/kernels/functions/eigen/mean.h"
#include "paddle/pten/kernels/functions/eigen/reduce.h"
#include "paddle/pten/kernels/functions/eigen/scale.h"
#include "paddle/pten/kernels/functions/eigen/sign.h"
#include "paddle/pten/kernels/functions/general/elementwise_functor.h"
#include "paddle/pten/kernels/functions/general/reduce_impl.h"
// See Note [ Why still include the fluid headers? ]
#include "paddle/fluid/framework/eigen.h"
......@@ -33,8 +35,16 @@ void Sign(const CPUContext& dev_ctx, const DenseTensor& x, DenseTensor* out) {
}
template <typename T>
void Mean(const CPUContext& dev_ctx, const DenseTensor& x, DenseTensor* out) {
eigen::Mean<CPUContext, T>(dev_ctx, x, out);
void Mean(const CPUContext& dev_ctx,
const DenseTensor& x,
const std::vector<int64_t>& dims,
bool keep_dim,
bool reduce_all,
DataType in_dtype,
DataType out_dtype,
DenseTensor* out) {
pten::general::Reduce<CPUContext, T, pten::eigen::MeanFunctor>(
dev_ctx, x, reduce_all, dims, keep_dim, out_dtype, out);
}
template <typename T>
......@@ -88,6 +98,19 @@ void ElementwiseDiv(const CPUContext& dev_ctx,
}
}
template <typename T>
void Sum(const CPUContext& dev_ctx,
const DenseTensor& x,
const std::vector<int64_t>& dims,
bool keep_dim,
bool reduce_all,
DataType in_dtype,
DataType out_dtype,
DenseTensor* out) {
pten::general::Reduce<CPUContext, T, pten::eigen::SumFunctor>(
dev_ctx, x, reduce_all, dims, keep_dim, out_dtype, out);
}
// Create the definition of ElementwiseAdd
DEFINE_CPU_ELEMENTWISE_OP(Add)
......@@ -109,8 +132,7 @@ using complex128 = ::paddle::platform::complex<double>;
// using bfloat16 = ::paddle::platform::bfloat16;
PT_REGISTER_KERNEL("sign", CPU, ANY, pten::Sign, float, double) {}
PT_REGISTER_KERNEL(
"mean", CPU, ANY, pten::Mean, float, double, paddle::platform::bfloat16) {}
PT_REGISTER_KERNEL("reduce_mean", CPU, ANY, pten::Mean, float, double, bool) {}
PT_REGISTER_KERNEL("scale",
CPU,
ANY,
......@@ -178,3 +200,18 @@ PT_REGISTER_KERNEL("elementwise_mul",
bool,
complex64,
complex128) {}
PT_REGISTER_KERNEL("reduce_sum",
CPU,
ANY,
pten::Sum,
bool,
float,
double,
paddle::platform::float16,
int,
int64_t,
complex64,
complex128) {
kernel->OutputAt(0).SetDataType(paddle::experimental::DataType::UNDEFINED);
}
......@@ -28,7 +28,14 @@ template <typename T>
void Sign(const CPUContext& dev_ctx, const DenseTensor& x, DenseTensor* out);
template <typename T>
void Mean(const CPUContext& dev_ctx, const DenseTensor& x, DenseTensor* out);
void Mean(const CPUContext& dev_ctx,
const DenseTensor& x,
const std::vector<int64_t>& dims,
bool keep_dim,
bool reduce_all,
DataType in_dtype,
DataType out_dtype,
DenseTensor* out);
template <typename T>
void Scale(const CPUContext& dev_ctx,
......@@ -73,6 +80,16 @@ void ElementwiseMul(const CPUContext& dev_ctx,
const DenseTensor& y,
int axis,
DenseTensor* out);
template <typename T>
void Sum(const CPUContext& dev_ctx,
const DenseTensor& x,
const std::vector<int64_t>& dims,
bool keep_dim,
bool reduce_all,
DataType in_dtype,
DataType out_dtype,
DenseTensor* out);
} // namespace pten
#define DEFINE_CPU_ELEMENTWISE_OP(name) \
......
if(WITH_GPU)
nv_library(math_cuda SRCS math.cu DEPS eigen_function dense_tensor convert_utils kernel_context kernel_factory)
nv_library(math_cuda SRCS math.cu DEPS eigen_function dense_tensor convert_utils kernel_context kernel_factory pten_transpose_cuda)
nv_library(linalg_cuda SRCS linalg.cu DEPS eigen_function dense_tensor kernel_context kernel_factory)
nv_library(creation_cuda SRCS creation.cu DEPS eigen_function dense_tensor kernel_context kernel_factory)
nv_library(utils_cuda SRCS utils.cu DEPS dense_tensor kernel_context kernel_factory memory convert_utils)
nv_library(manipulation_cuda SRCS manipulation.cu DEPS dense_tensor kernel_context kernel_factory utils_cuda unary)
elseif(WITH_ROCM)
hip_library(math_cuda SRCS math.cu DEPS eigen_function dense_tensor convert_utils kernel_context kernel_factory)
hip_library(math_cuda SRCS math.cu DEPS eigen_function dense_tensor convert_utils kernel_context kernel_factory pten_transpose_cuda)
hip_library(linalg_cuda SRCS linalg.cu DEPS eigen_function dense_tensor kernel_context kernel_factory)
hip_library(creation_cuda SRCS creation.cu DEPS eigen_function dense_tensor kernel_context kernel_factory)
hip_library(utils_cuda SRCS utils.cu DEPS dense_tensor kernel_context kernel_factory memory convert_utils)
......
......@@ -14,11 +14,13 @@ limitations under the License. */
#include "paddle/pten/kernels/cuda/math.h"
#include "paddle/fluid/operators/reduce_ops/reduce_functor_op.h"
#include "paddle/pten/kernels/functions/cuda/elementwise/elementwise.h"
#include "paddle/pten/kernels/functions/eigen/mean.h"
#include "paddle/pten/kernels/functions/cuda/reduce/reduce.h"
#include "paddle/pten/kernels/functions/eigen/scale.h"
#include "paddle/pten/kernels/functions/eigen/sign.h"
#include "paddle/pten/kernels/functions/general/elementwise_functor.h"
#include "paddle/pten/kernels/functions/general/reduce_impl.h"
#ifdef __NVCC__
#include "cub/cub.cuh"
......@@ -62,37 +64,16 @@ void Sign(const CUDAContext& dev_ctx, const DenseTensor& x, DenseTensor* out) {
}
template <typename T>
void Mean(const CUDAContext& dev_ctx, const DenseTensor& x, DenseTensor* out) {
auto size_prob = x.numel();
const T* x_data = x.data<T>();
T* out_data = out->mutable_data<T>();
auto stream = dev_ctx.stream();
DivideFunctor<T> transformer(size_prob);
cub::TransformInputIterator<T, DivideFunctor<T>, const T*> trans_x(
x_data, transformer);
size_t temp_storage_bytes = 0;
auto err = cub::DeviceReduce::Sum(
nullptr, temp_storage_bytes, trans_x, out_data, size_prob, stream);
PADDLE_ENFORCE_CUDA_SUCCESS(err);
const auto alloc = std::make_shared<paddle::experimental::DefaultAllocator>(
dev_ctx.GetPlace());
pten::DenseTensor tmp(
alloc,
DenseTensorMeta(x.dtype(),
paddle::framework::make_ddim(
{static_cast<int64_t>(temp_storage_bytes)}),
x.layout()));
void* temp_storage = tmp.mutable_data<T>();
err = cub::DeviceReduce::Sum(static_cast<uint8_t*>(temp_storage),
temp_storage_bytes,
trans_x,
out_data,
size_prob,
stream);
PADDLE_ENFORCE_CUDA_SUCCESS(err);
void Mean(const CUDAContext& dev_ctx,
const DenseTensor& x,
const std::vector<int64_t>& dims,
bool keep_dim,
bool reduce_all,
DataType in_dtype,
DataType out_dtype,
DenseTensor* out) {
pten::Reduce<T, paddle::operators::CustomMean>(
dev_ctx, x, reduce_all, dims, keep_dim, out_dtype, out);
}
template <typename T>
......@@ -133,6 +114,19 @@ DEFINE_CUDA_ELEMENTWISE_OP(Mul)
// Create the definition of ElementwiseDiv
DEFINE_CUDA_ELEMENTWISE_OP(Div)
template <typename T>
void Sum(const CUDAContext& dev_ctx,
const DenseTensor& x,
const std::vector<int64_t>& dims,
bool keep_dim,
bool reduce_all,
DataType in_dtype,
DataType out_dtype,
DenseTensor* out) {
pten::Reduce<T, paddle::operators::CustomSum>(
dev_ctx, x, reduce_all, dims, keep_dim, out_dtype, out);
}
} // namespace pten
// TODO(chenweihang): replace by better impl
......@@ -143,7 +137,7 @@ using complex64 = ::paddle::platform::complex<float>;
using complex128 = ::paddle::platform::complex<double>;
PT_REGISTER_KERNEL("sign", CUDA, ANY, pten::Sign, float, double, float16) {}
PT_REGISTER_KERNEL("mean", CUDA, ANY, pten::Mean, float, double, float16) {}
PT_REGISTER_KERNEL("reduce_mean", CUDA, ANY, pten::Mean, float, double, bool) {}
PT_REGISTER_KERNEL("scale",
CUDA,
ANY,
......@@ -215,3 +209,17 @@ PT_REGISTER_KERNEL("elementwise_mul",
float16,
complex64,
complex128) {}
PT_REGISTER_KERNEL("reduce_sum",
CUDA,
ANY,
pten::Sum,
bool,
float,
double,
float16,
int,
int64_t,
complex64,
complex128) {
kernel->OutputAt(0).SetDataType(paddle::experimental::DataType::UNDEFINED);
}
......@@ -30,7 +30,14 @@ template <typename T>
void Sign(const CUDAContext& dev_ctx, const DenseTensor& x, DenseTensor* out);
template <typename T>
void Mean(const CUDAContext& dev_ctx, const DenseTensor& x, DenseTensor* out);
void Mean(const CUDAContext& dev_ctx,
const DenseTensor& x,
const std::vector<int64_t>& dims,
bool keep_dim,
bool reduce_all,
DataType in_dtype,
DataType out_dtype,
DenseTensor* out);
template <typename T>
void Scale(const CUDAContext& dev_ctx,
......@@ -75,6 +82,17 @@ void ElementwiseMul(const CUDAContext& dev_ctx,
const DenseTensor& y,
int axis,
DenseTensor* out);
template <typename T>
void Sum(const CUDAContext& dev_ctx,
const DenseTensor& x,
const std::vector<int64_t>& dims,
bool keep_dim,
bool reduce_all,
DataType in_dtype,
DataType out_dtype,
DenseTensor* out);
} // namespace pten
#define DEFINE_CUDA_ELEMENTWISE_OP(name) \
......
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
// CUDA and HIP use same api
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
#include "paddle/pten/common/scalar.h"
#include "paddle/pten/core/dense_tensor.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/pten/kernels/functions/cuda/reduce/reduce_cuda_impl.h"
namespace pten {
using CUDAContext = paddle::platform::CUDADeviceContext;
static inline std::vector<int64_t> GetReduceDim(
const std::vector<int64_t>& dims, int dim_size, bool reduce_all) {
std::vector<int64_t> reduce_dims;
if (reduce_all) {
reduce_dims.resize(dim_size);
int reduce_size = reduce_dims.size();
for (int i = 0; i < reduce_size; ++i) {
reduce_dims[i] = i;
}
} else {
for (auto e : dims) {
PADDLE_ENFORCE_LT(e,
dim_size,
paddle::platform::errors::InvalidArgument(
"ReduceOp: invalid axis, when x_dims is %d, "
"axis[i] should less than x_dims, but got %d.",
dim_size,
e));
reduce_dims.push_back(e >= 0 ? e : e + dim_size);
}
}
return reduce_dims;
}
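// Example (illustration, not in the original source): GetReduceDim({-1}, 3,
// false) returns {2}; GetReduceDim({}, 3, true) returns {0, 1, 2}.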
template <typename T, template <typename, typename> class ReduceFunctor>
void Reduce(const CUDAContext& dev_ctx,
const DenseTensor& x,
bool reduce_all,
const std::vector<int64_t>& dims,
bool keep_dim,
DataType out_dtype,
DenseTensor* out) {
std::vector<int64_t> reduce_dims =
GetReduceDim(dims, x.dims().size(), reduce_all);
gpuStream_t stream = dev_ctx.stream();
if (out_dtype != pten::DataType::UNDEFINED) {
PD_DISPATCH_FLOATING_AND_INTEGRAL_AND_COMPLEX_TYPES(
out_dtype, "TensorReduceFunctorImpl", ([&] {
pten::detail::TensorReduceFunctorImpl<T, data_t, ReduceFunctor>(
x, out, reduce_dims, stream);
}));
} else {
pten::detail::TensorReduceFunctorImpl<T, T, ReduceFunctor>(
x, out, reduce_dims, stream);
}
}
} // namespace pten
#endif
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <algorithm>
#include <cmath>
#include <numeric>
#include <set>
#include <vector>
#ifdef __NVCC__
#include "cub/cub.cuh"
#endif
#ifdef __HIPCC__
#include <hipcub/hipcub.hpp>
namespace cub = hipcub;
#endif
#include "paddle/fluid/framework/array.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/operators/amp/fp16_type_traits.h"
#include "paddle/fluid/operators/cast_op.h"
#include "paddle/fluid/operators/kernel_primitives/kernel_primitives.h"
#include "paddle/fluid/platform/cuda_device_function.h"
#include "paddle/fluid/platform/fast_divmod.h"
#include "paddle/fluid/operators/kernel_primitives/compute_primitives.h"
#include "paddle/pten/api/ext/dispatch.h"
#include "paddle/pten/api/include/tensor.h"
#include "paddle/pten/kernels/cuda/utils.h"
#include "paddle/pten/kernels/functions/math/cast_func.h"
// Whether to split the reduce, i.e. whether to use ReduceHigherDim
#define REDUCE_SPLIT_BOUNDARY 512
#define REDUCE_VEC_SIZE 4
namespace pten {
namespace detail {
namespace kps = paddle::operators::kernel_primitives;
namespace details {
static inline int GetLastPow2(int n) {
n |= (n >> 1);
n |= (n >> 2);
n |= (n >> 4);
n |= (n >> 8);
n |= (n >> 16);
return std::max(1, n - (n >> 1));
}
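// For illustration (not in the original source): GetLastPow2(20) == 16 and
// GetLastPow2(1) == 1, i.e. the largest power of two not exceeding n, with a
// floor of 1.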
static inline int64_t AlignUp(int64_t a, int64_t b) { return (a + b - 1) / b; }
// get strides of x_dim, reduce_dim and left_dim for reduceLastDim and reduceAny
static inline std::vector<int64_t> GetDimStrides(
const std::vector<int64_t>& dims, const std::vector<int64_t>& idx) {
int n = static_cast<int>(idx.size());
if (n == 0) return std::vector<int64_t>();
std::vector<int64_t> strides(n);
strides.back() = 1;
for (int i = n - 2; i >= 0; --i) {
strides[i] = strides[i + 1] * dims[idx[i + 1]];
}
return strides;
}
// get blockDim for reduceLastDim and reduceAny
static inline int GetBlockDim(int block_dim) {
return block_dim >= kps::details::kReduceMaxThread
? kps::details::kReduceMaxThread
: GetLastPow2(block_dim);
}
// check that the reduce rank is valid
static inline void CheckReduceRank(int reduce_rank, int rank) {
if (rank % 2 == 0) {
PADDLE_ENFORCE_EQ(reduce_rank,
rank / 2,
paddle::platform::errors::InvalidArgument(
"ReduceOp: invalid reduce rank. When rank = %d, "
"reduce_rank must be %d, but got %d.",
rank,
rank / 2,
reduce_rank));
} else {
auto lower_rank = (rank - 1) / 2;
auto upper_rank = (rank + 1) / 2;
PADDLE_ENFORCE_EQ(
reduce_rank == lower_rank || reduce_rank == upper_rank,
true,
paddle::platform::errors::InvalidArgument(
"ReduceOp: invalid reduce rank. When rank = %d, reduce_rank "
"must be %d or %d, but got %d.",
rank,
lower_rank,
upper_rank,
reduce_rank));
}
}
// convert dims from vector to array
template <typename T, size_t ElementCount, typename VectorLikeType>
static inline paddle::framework::Array<T, ElementCount> VectorToArray(
const VectorLikeType& vec) {
PADDLE_ENFORCE_LE(vec.size(),
ElementCount,
paddle::platform::errors::InvalidArgument(
"Cub reduce Array: size not match. Received "
"vec.size() %d > ElementCount %d.",
vec.size(),
ElementCount));
size_t n = static_cast<size_t>(vec.size());
paddle::framework::Array<T, ElementCount> ret;
for (size_t i = 0; i < n; ++i) {
ret[i] = vec[i];
}
return ret;
}
} // namespace details
constexpr int kMaxRank = pten::DDim::kMaxRank;
enum ReduceType {
kReduceLastDim = 0x01, // when reduce_dim[0] == x_dim.size() - 1;
kReduceHigherDim = 0x02, // ReduceFirstDim or reduceSecondDim
kReduceAny = 0x03, // when reduce_dim.size() > 1
};
struct IndexCalculator {
IndexCalculator(int dim,
const std::vector<int64_t>& cal_dims,
const std::vector<int64_t>& cal_strides,
const std::vector<int64_t>& full_strides)
: dim(dim) {
dims = details::VectorToArray<int, kMaxRank>(cal_dims);
strides = details::VectorToArray<int, kMaxRank>(full_strides);
std::vector<paddle::platform::FastDivMod> cal_divmoders;
// fast divmod
for (auto i : cal_strides) {
cal_divmoders.push_back(paddle::platform::FastDivMod(i));
}
divmoders = details::VectorToArray<paddle::platform::FastDivMod, kMaxRank>(
cal_divmoders);
}
__device__ inline int operator()(int offset) const {
int index = 0;
#pragma unroll
for (int i = 0; i < kMaxRank; ++i) {
if (i == dim) {
break;
}
auto divmod = divmoders[i].Divmod(offset);
index += (divmod.val[0] * strides[dims[i]]);
offset = divmod.val[1];
}
return index;
}
int dim;
paddle::framework::Array<int, kMaxRank> dims;
paddle::framework::Array<int, kMaxRank> strides;
paddle::framework::Array<paddle::platform::FastDivMod, kMaxRank> divmoders;
};
template <bool ReduceLastDim = false>
struct ReduceIndexMapping {
const kps::DimConfig dim;
HOSTDEVICE explicit ReduceIndexMapping(const kps::DimConfig& dims)
: dim(dims) {}
__device__ __forceinline__ int BlockIdX() {
#ifdef PADDLE_WITH_XPU2
if (ReduceLastDim) {
return (cluster_id() / dim.split_num_x % dim.split_num_y);
} else {
return cluster_id() % dim.split_num_x;
}
#else
return blockIdx.x;
#endif
}
__device__ __forceinline__ int BlockIdY() {
#ifdef PADDLE_WITH_XPU2
if (ReduceLastDim) {
return (cluster_id() % dim.split_num_x);
} else {
return (cluster_id() / dim.split_num_x % dim.split_num_y);
}
#else
return blockIdx.y;
#endif
}
__device__ __forceinline__ int BlockDimX() {
#ifdef PADDLE_WITH_XPU2
return dim.deal_size_x;
#else
return blockDim.x;
#endif
}
__device__ __forceinline__ int BlockDimY() {
#ifdef PADDLE_WITH_XPU2
return dim.deal_size_y;
#else
return blockDim.y;
#endif
}
__device__ __forceinline__ int GridDimX() {
#ifdef PADDLE_WITH_XPU2
if (ReduceLastDim) {
return dim.split_num_y;
} else {
return dim.split_num_x;
}
#else
return gridDim.x;
#endif
}
__device__ __forceinline__ int GridDimY() {
#ifdef PADDLE_WITH_XPU2
if (ReduceLastDim) {
return dim.split_num_x;
} else {
return dim.split_num_y;
}
#else
return gridDim.y;
#endif
}
__device__ __forceinline__ int GetLoopSize() {
#ifdef PADDLE_WITH_XPU2
if (ReduceLastDim) {
return dim.deal_size_y;
} else {
return dim.deal_size_x;
}
#else
return 1;
#endif
}
};
// when reduce_type == kReduceLastDim this struct will be used
// for higher performance
struct OneDimIndexCal {
explicit OneDimIndexCal(int num) : stride(num) {}
__device__ inline int operator()(int index) const { return index * stride; }
int stride;
};
// reduce config
template <typename Ty>
struct ReduceConfig {
ReduceConfig(const std::vector<int64_t>& origin_reduce_dims,
const std::vector<int64_t>& origin_x_dim)
: reduce_dims_origin(origin_reduce_dims), x_dim(origin_x_dim) {}
// get the parameters of reduceKernel
void Run() {
// step1: update the reduce_dim left_dim and x_dim
SetReduceDim();
// step2: get the strides of dim for reduceAny and reduceLastDim
SetStrides();
// step3: get the type of reduce
SetReduceType();
// step4: set the block and grid for launch kernel
SetBlockDim();
}
// when should_reduce_again is true, we need to malloc temp space for the temp data
void SetOutputData(Ty* y_data,
const paddle::platform::Place& place,
pten::DenseTensor* tmp) {
if (should_reduce_again) {
tmp->Resize(paddle::framework::make_ddim(
{static_cast<int64_t>(left_num * grid.z * grid.y * sizeof(Ty))}));
output_data = tmp->mutable_data<Ty>();
} else {
output_data = y_data;
}
}
private:
// set reduce_dim, left_dim and update x_dim
// eg: x_dim = [2, 4, 6] origin_reduce_dims = [0, 1]
// --SetReduceDim--> x_dim = [8,6], reduce_dim = [0], left_dim = [1]
void SetReduceDim() {
std::set<int64_t> reduce_set;
for (auto e : reduce_dims_origin) {
auto pos = e >= 0 ? e : e + x_dim.size();
reduce_set.insert(pos);
}
std::vector<int64_t> reduce_dim_temp(reduce_set.begin(), reduce_set.end());
std::sort(reduce_dim_temp.begin(), reduce_dim_temp.end());
// update reduce_dim and x_dim
std::vector<int64_t> x_new_dim;
reduce_dim.push_back(reduce_dim_temp[0]);
x_new_dim.push_back(x_dim[0]);
int idx_reduce = 1;
int num = 0;
if (reduce_dim_temp.size() > 1) {
for (int i = 1; i < x_dim.size(); i++) {
if ((idx_reduce < reduce_dim_temp.size()) &&
(i == reduce_dim_temp[idx_reduce])) {
int result =
reduce_dim_temp[idx_reduce] - reduce_dim[reduce_dim.size() - 1];
bool is_equal = ((result - num) == 1);
if (is_equal) {
x_new_dim[x_new_dim.size() - 1] *= x_dim[i];
num++;
} else {
reduce_dim.push_back(reduce_dim_temp[idx_reduce] - num);
x_new_dim.push_back(x_dim[i]);
}
idx_reduce++;
} else {
x_new_dim.push_back(x_dim[i]);
}
}
} else {
x_new_dim = x_dim;
}
// update x_dim
x_dim = x_new_dim;
std::vector<int64_t>().swap(x_new_dim);
std::vector<int64_t> reduce_dim_new;
int is_reduced = 0;
for (auto e : reduce_dim) {
is_reduced |= 1 << e;
}
std::vector<int64_t>().swap(reduce_dim);
for (int i = 0; i < x_dim.size(); i++) {
if ((i == 0) || (((is_reduced >> i) ^ (is_reduced >> (i - 1))) & 1)) {
x_new_dim.push_back(x_dim[i]);
if ((is_reduced >> i) & 1)
reduce_dim_new.push_back(x_new_dim.size() - 1);
} else {
x_new_dim[x_new_dim.size() - 1] *= x_dim[i];
}
}
x_dim = x_new_dim;
reduce_dim = reduce_dim_new;
int x_rank = static_cast<int>(x_dim.size());
std::set<int> left_set;
for (int i = 0; i < x_rank; ++i) {
left_set.insert(i);
}
for (auto e : reduce_dim) {
left_set.erase(e);
}
left_dim.assign(left_set.begin(), left_set.end());
// if the last dim gets involved in reduction
reduce_last_dim = (reduce_dim.back() == x_dim.size() - 1);
}
// set x_strides, reduce_strides, left_strides for reduceLastDim and reduceAny
// eg: x_dim = [8, 6], reduce_dim = [0], left_dim = [1]
// --SetStrides--> x_strides= [6,1], reduce_strides = [1],
// left_strides = [1]
void SetStrides() {
std::vector<int64_t> idx_dim;
for (int i = 0; i < x_dim.size(); i++) {
idx_dim.push_back(i);
}
x_strides = details::GetDimStrides(x_dim, idx_dim);
reduce_strides = details::GetDimStrides(x_dim, reduce_dim);
left_strides = details::GetDimStrides(x_dim, left_dim);
reduce_num = reduce_strides[0] * x_dim[reduce_dim[0]];
left_num = 1;
if (left_dim.size()) {
left_num = left_strides[0] * x_dim[left_dim[0]];
}
}
// get the reduceType
// eg: x_dim = [8, 6] reduce_dim = [0] --> ReduceHigherDim -->reduceFirstDim
// x_dim = [8, 6] reduce_dim = [1] --> reduceLastDim
// x_dim = [8] reduce_dim = [0] --> reduceAll
// x_dim = [8, 6, 4, 2] reduce_dim = [0, 2] --> reduceAny
void SetReduceType() {
int rank = x_dim.size();
int reduce_rank = reduce_dim.size();
bool is_last_dim =
(rank == 2) && (reduce_rank == 1) && (reduce_dim[0] == 1);
if (rank == reduce_rank || is_last_dim) {
reduce_type = static_cast<int>(ReduceType::kReduceLastDim);
} else if (reduce_rank == 1) {
// ReduceFirstDim and reduceSecondDim
#ifdef PADDLE_WITH_XPU2
if (reduce_dim[0] == 0) {
reduce_type = static_cast<int>(ReduceType::kReduceHigherDim);
} else {
reduce_type = static_cast<int>(ReduceType::kReduceAny);
}
#else
reduce_type = static_cast<int>(ReduceType::kReduceHigherDim);
#endif
} else {
reduce_type = static_cast<int>(ReduceType::kReduceAny);
}
}
void SetBlockDimForReduceAny(dim3* block_dim, dim3* grid_dim) {
constexpr int min_reduce_num_per_thread = 16;
constexpr int max_reduce_num_per_thread = 256;
constexpr int max_num_threads = kps::details::kReduceMaxThread;
// set block size.
// 1. If reduce_last_dim == true, all threads with the same threadIdx.y
// process the reduction for one output.
// The number of outputs per block is blockDim.y;
// 2. If reduce_last_dim == false, different threadIdx.x values process
// different reductions and get their outputs separately. If necessary,
// a further reduction is done along block y.
// The number of outputs per block is blockDim.x;
int block_x, block_y;
int grid_num, reduce_num_per_thread;
if (reduce_last_dim) {
block_x = details::GetBlockDim(reduce_num);
block_y = details::GetBlockDim(left_num);
block_dim->x = block_x;
block_dim->y =
std::min(block_y, static_cast<int>(max_num_threads / block_dim->x));
grid_num = details::AlignUp(left_num, block_dim->y);
reduce_num_per_thread = details::AlignUp(reduce_num, block_dim->x);
} else {
block_x = details::GetBlockDim(left_num);
block_y = details::GetBlockDim(reduce_num);
block_dim->x = std::min(block_x, 32);
block_dim->y =
std::min(block_y, static_cast<int>(max_num_threads / block_dim->x));
block_dim->x =
std::min(block_x, static_cast<int>(max_num_threads / block_dim->y));
grid_num = details::AlignUp(left_num, block_dim->x);
reduce_num_per_thread = details::AlignUp(reduce_num, block_dim->y);
}
int device_id = paddle::platform::GetCurrentDeviceId();
int max_mp = paddle::platform::GetCUDAMultiProcessors(device_id);
int max_threads_per_mp =
paddle::platform::GetCUDAMaxThreadsPerMultiProcessor(device_id);
int max_threads = max_threads_per_mp * max_mp;
int num_threads = block_dim->x * block_dim->y;
int max_num_blocks = max_threads / num_threads;
// set grid size.
// Whether grid.y is set larger than 1 follows these rules:
// 1. The number of elements each thread processes should be no less than
// min_reduce_num_per_thread but no more than max_reduce_num_per_thread;
// 2. It should maximize the utilization of the SMs.
// So we choose the minimum of input_split_num_1 and input_split_num_3
// to make each thread process as much data as possible. Meanwhile,
// the number cannot be larger than max_reduce_num_per_thread, so we
// choose the maximum of the result above and input_split_num_2.
int input_split_num_1 =
details::AlignUp(reduce_num_per_thread, min_reduce_num_per_thread);
int input_split_num_2 =
details::AlignUp(reduce_num_per_thread, max_reduce_num_per_thread);
int input_split_num_3 = details::AlignUp(max_num_blocks, grid_num);
grid_dim->x = grid_num;
grid_dim->y = std::max(std::min(input_split_num_1, input_split_num_3),
input_split_num_2);
// if grid.y > 1, we need launch reduce kernel again.
if (grid_dim->y > 1) {
should_reduce_again = true;
}
}
// set block and grid for launch kernel
// for ReduceHigherDim: if the block is enough -> split reduce_num
// else init block(32, 1) grid(block_num, 1)
// for others: block(block_num, 1) , grid(left_num, 1)
void SetBlockDimForHigher(dim3* block_dim, dim3* grid_dim) {
int last_dim_num = x_dim.back();
// update left_num
int grid_z = left_num / last_dim_num;
left_num = last_dim_num;
grid_dim->z = grid_z;
int device_id = paddle::platform::GetCurrentDeviceId();
int max_mp = paddle::platform::GetCUDAMultiProcessors(device_id);
int max_threads_per_mp =
paddle::platform::GetCUDAMaxThreadsPerMultiProcessor(device_id);
int max_threads = max_threads_per_mp * max_mp;
// init
int num_block = (max_threads / left_num);
block_dim->x = details::GetBlockDim(left_num);
grid_dim->x = details::AlignUp(left_num, block_dim->x);
blocking_size = reduce_num;
if (num_block > 1 && reduce_num >= REDUCE_SPLIT_BOUNDARY) {
blocking_size = details::GetLastPow2(reduce_num / num_block);
if (blocking_size <= 1) {
blocking_size = details::GetLastPow2(sqrt(reduce_num));
} else if (blocking_size * 2 < reduce_num) {
blocking_size *= 2;
}
should_reduce_again = true;
grid_dim->y = details::AlignUp(reduce_num, blocking_size);
}
}
void SetBlockDim() {
// init
int block_num = details::GetBlockDim(reduce_num);
should_reduce_again = false;
dim3 block_dim(block_num, 1, 1);
dim3 grid_dim(left_num, 1, 1);
blocking_size = reduce_num;
#ifdef PADDLE_WITH_XPU2
if (reduce_last_dim) {
block_dim.x = 128;
block_dim.y = reduce_num;
grid_dim.x = 8;
grid_dim.y = 1;
} else {
block_dim.x = 128;
block_dim.y = left_num;
grid_dim.x = 8;
grid_dim.y = 1;
}
#else
if (reduce_type == ReduceType::kReduceHigherDim) {
SetBlockDimForHigher(&block_dim, &grid_dim);
} else {
SetBlockDimForReduceAny(&block_dim, &grid_dim);
}
#endif
block = block_dim;
grid = grid_dim;
}
public:
std::vector<int64_t> reduce_dims_origin;
std::vector<int64_t> reduce_dim;
std::vector<int64_t> x_dim;
std::vector<int64_t> left_dim;
std::vector<int64_t> x_strides;
std::vector<int64_t> left_strides;
std::vector<int64_t> reduce_strides;
int reduce_type;
int reduce_num;
int left_num;
int blocking_size;
bool should_reduce_again;
bool reduce_last_dim;
Ty* output_data;
dim3 block;
dim3 grid;
};
template <typename Tx, typename Ty, typename MPType, typename ReduceOp>
static void LaunchReduceKernel(const Tx* x_data,
Ty* y_data,
const ReduceOp& reducer,
MPType init,
gpuStream_t stream,
ReduceConfig<Ty> config) {
using TransformOp = typename ReduceOp::Transformer;
if (config.reduce_type == kReduceLastDim) {
int stride_reduce = 1;
int stride_left = config.reduce_num;
// for higher performance
auto reduce_index_calculator = OneDimIndexCal(stride_reduce);
auto left_index_calculator = OneDimIndexCal(stride_left);
kps::DimConfig dim = kps::DimConfig(config.grid.x,
config.grid.y,
config.grid.z,
config.block.x,
config.block.y,
0);
dim.SetRem(config.reduce_num % config.block.x, 0, 0);
#ifdef PADDLE_WITH_XPU2
paddle::operators::ReduceAnyKernel<Tx,
Ty,
MPType,
ReduceOp,
TransformOp,
OneDimIndexCal><<<8, 128, stream>>>(
x_data,
config.output_data,
reducer,
TransformOp(config.reduce_num),
init,
config.reduce_num,
config.left_num,
config.reduce_last_dim,
reduce_index_calculator,
left_index_calculator,
dim);
#else
paddle::operators::ReduceAnyKernel<
Tx,
Ty,
MPType,
ReduceOp,
TransformOp,
OneDimIndexCal><<<config.grid, config.block, 0, stream>>>(
x_data,
config.output_data,
reducer,
TransformOp(config.reduce_num),
init,
config.reduce_num,
config.left_num,
config.reduce_last_dim,
reduce_index_calculator,
left_index_calculator,
dim);
#endif
} else {
int reduce_rank = config.reduce_strides.size();
int left_rank = config.left_strides.size();
auto reduce_index_calculator = IndexCalculator(reduce_rank,
config.reduce_dim,
config.reduce_strides,
config.x_strides);
auto left_index_calculator = IndexCalculator(
left_rank, config.left_dim, config.left_strides, config.x_strides);
kps::DimConfig dim = kps::DimConfig(config.grid.x,
config.grid.y,
config.grid.z,
config.block.x,
config.block.y,
0);
dim.SetRem(config.reduce_num % config.block.x, 0, 0);
#ifdef PADDLE_WITH_XPU2
paddle::operators::ReduceAnyKernel<Tx,
Ty,
MPType,
ReduceOp,
TransformOp,
IndexCalculator><<<8, 128, stream>>>(
x_data,
config.output_data,
reducer,
TransformOp(config.reduce_num),
init,
config.reduce_num,
config.left_num,
config.reduce_last_dim,
reduce_index_calculator,
left_index_calculator,
dim);
#else
paddle::operators::ReduceAnyKernel<
Tx,
Ty,
MPType,
ReduceOp,
TransformOp,
IndexCalculator><<<config.grid, config.block, 0, stream>>>(
x_data,
config.output_data,
reducer,
TransformOp(config.reduce_num),
init,
config.reduce_num,
config.left_num,
config.reduce_last_dim,
reduce_index_calculator,
left_index_calculator,
dim);
#endif
}
if (config.should_reduce_again) {
dim3 block;
dim3 grid;
if (config.reduce_last_dim) {
block = dim3(32, 1, 1);
grid = dim3(details::AlignUp(config.left_num, 32), 1, 1);
} else {
block = dim3(config.block.x, 1, 1);
grid = dim3(config.grid.x, 1, config.grid.z);
}
auto last_index = OneDimIndexCal(1);
auto first_index = OneDimIndexCal(config.left_num);
kps::DimConfig dim =
kps::DimConfig(grid.x, grid.y, grid.z, block.x, config.grid.y, 0);
dim.SetRem(config.left_num % block.x, 0, 0);
#ifdef PADDLE_WITH_XPU2
paddle::operators::ReduceHigherDimKernel<
Ty,
Ty,
MPType,
ReduceOp,
kps::IdentityFunctor<Ty, MPType>><<<8, 128, stream>>>(
config.output_data,
y_data,
reducer,
kps::IdentityFunctor<Ty, MPType>(config.grid.y),
init,
config.grid.y,
config.left_num,
config.grid.y,
dim);
#else
paddle::operators::ReduceHigherDimKernel<
Ty,
Ty,
MPType,
ReduceOp,
kps::IdentityFunctor<Ty, MPType>><<<grid, block, 0, stream>>>(
config.output_data,
y_data,
reducer,
kps::IdentityFunctor<Ty, MPType>(config.grid.y),
init,
config.grid.y,
config.left_num,
config.grid.y,
dim);
#endif
}
}
template <typename Tx,
typename Ty,
template <typename, typename> class ReduceOp>
void TensorReduceFunctorImpl(const pten::DenseTensor& x,
pten::DenseTensor* y,
std::vector<int64_t> origin_reduce_dims,
gpuStream_t stream) {
// Allocate memory
y->mutable_data<Ty>();
auto x_dim = paddle::framework::vectorize<int64_t>(x.dims());
auto config = ReduceConfig<Ty>(origin_reduce_dims, x_dim);
config.Run();
int64_t numel = x.numel();
// After config.Run(): SetOutputData for ReduceHigherDim. When
// should_reduce_again is true, the intermediate result is stored in a
// temporary buffer (output_data); otherwise it is written directly to y_data.
pten::DDim tmp_ddim;
const auto alloc =
std::make_shared<paddle::experimental::DefaultAllocator>(y->place());
pten::DenseTensor tmp = pten::DenseTensor(
alloc, pten::DenseTensorMeta(y->dtype(), tmp_ddim, y->layout()));
auto x_data = x.data<Tx>();
auto y_data = y->mutable_data<Ty>();
auto* dev_ctx = static_cast<paddle::platform::CUDADeviceContext*>(
paddle::platform::DeviceContextPool::Instance().Get(x.place()));
if (config.reduce_num == 1) {
auto out_dims = y->dims();
if (x.dtype() == y->dtype()) {
pten::Copy(*dev_ctx, x, true, y);
y->Resize(out_dims);
} else {
PD_VISIT_ALL_TYPES(y->dtype(), "CastKernelImpl", ([&] {
pten::math::CastKernelImpl<CUDAContext, Tx, data_t>(
*dev_ctx, x, y);
}));
}
return;
}
config.SetOutputData(y_data, x.place(), &tmp);
bool use_cub_reduce = (config.reduce_num == numel) &&
(!std::is_same<Tx, paddle::platform::float16>::value);
if (use_cub_reduce) {
// launch CUB::Reduce
using TransformOp = typename ReduceOp<Tx, Ty>::Transformer;
auto reducer = ReduceOp<Tx, Ty>();
cub::TransformInputIterator<Ty, TransformOp, const Tx*> trans_x(
x_data, TransformOp(config.reduce_num));
size_t temp_storage_bytes = 0;
cub::DeviceReduce::Reduce(nullptr,
temp_storage_bytes,
trans_x,
y_data,
config.reduce_num,
reducer,
reducer.initial(),
stream);
// framework::Tensor tmp;
const auto alloc =
std::make_shared<paddle::experimental::DefaultAllocator>(x.place());
pten::DenseTensor tmp = pten::DenseTensor(
alloc,
pten::DenseTensorMeta(pten::DataType::UINT8,
paddle::framework::make_ddim(
{static_cast<int64_t>(temp_storage_bytes)}),
x.layout()));
auto* temp_storage = tmp.mutable_data<uint8_t>();
cub::DeviceReduce::Reduce(temp_storage,
temp_storage_bytes,
trans_x,
y_data,
config.reduce_num,
reducer,
reducer.initial(),
stream);
return;
}
using MPType =
typename paddle::operators::kernel_primitives::details::MPTypeTrait<
Ty>::Type;
auto reducer = ReduceOp<Tx, MPType>();
// launch ReduceHigherDimKernel
// when reduce_dim.size() == 1 and reduce_dim[0] != x_dim.size() - 1, this
// function will be used
// eg: x_dim = {nz, ny, nx}, nx != 1, axis can be 0 or 1
// if axis = 1 then grid.z = nz, grid.y = ny / block_size, grid.x = nx /
// 32
// else grid.z = 1, grid.y = ny / block_size, grid.x = nx /32
if (config.reduce_type == ReduceType::kReduceHigherDim) {
using TransformOp = typename ReduceOp<Tx, MPType>::Transformer;
kps::DimConfig dim = kps::DimConfig(config.grid.x,
config.grid.y,
config.grid.z,
config.block.x,
config.blocking_size,
0);
dim.SetRem(config.left_num % config.block.x,
config.reduce_num % config.blocking_size,
0);
#ifdef PADDLE_WITH_XPU2
paddle::operators::ReduceHigherDimKernel<Tx,
Ty,
MPType,
ReduceOp<Tx, MPType>,
TransformOp><<<8, 128, stream>>>(
x_data,
config.output_data,
reducer,
TransformOp(config.reduce_num),
reducer.initial(),
config.reduce_num,
config.left_num,
config.blocking_size,
dim);
#else
paddle::operators::ReduceHigherDimKernel<
Tx,
Ty,
MPType,
ReduceOp<Tx, MPType>,
TransformOp><<<config.grid, config.block, 0, stream>>>(
x_data,
config.output_data,
reducer,
TransformOp(config.reduce_num),
reducer.initial(),
config.reduce_num,
config.left_num,
config.blocking_size,
dim);
#endif
if (config.should_reduce_again) {
dim3 block = dim3(config.block.x, 1, 1);
dim3 grid = dim3(config.grid.x, 1, config.grid.z);
kps::DimConfig dim2 =
kps::DimConfig(grid.x, grid.y, grid.z, block.x, config.grid.y, 0);
dim2.SetRem(config.left_num % config.block.x, 0, 0);
#ifdef PADDLE_WITH_XPU2
paddle::operators::ReduceHigherDimKernel<
Ty,
Ty,
MPType,
ReduceOp<Tx, MPType>,
kps::IdentityFunctor<Ty, MPType>><<<8, 128, stream>>>(
config.output_data,
y_data,
reducer,
kps::IdentityFunctor<Ty, MPType>(config.grid.y),
reducer.initial(),
config.grid.y,
config.left_num,
config.grid.y,
dim2);
#else
paddle::operators::ReduceHigherDimKernel<
Ty,
Ty,
MPType,
ReduceOp<Tx, MPType>,
kps::IdentityFunctor<Ty, MPType>><<<grid, block, 0, stream>>>(
config.output_data,
y_data,
reducer,
kps::IdentityFunctor<Ty, MPType>(config.grid.y),
reducer.initial(),
config.grid.y,
config.left_num,
config.grid.y,
dim2);
#endif
}
return;
}
// when reduce_dim.size() == 1 and reduce_dim[0] == x_dim.size() - 1, or
// when reduce_dim.size() != 1 and reduce_dim.size() != x_dim.size(), this
// function will be used
LaunchReduceKernel<Tx, Ty, MPType, ReduceOp<Tx, MPType>>(
x_data, y_data, reducer, reducer.initial(), stream, config);
}
} // namespace detail
} // namespace pten
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/pten/core/dense_tensor.h"
#include "paddle/pten/kernels/functions/eigen/common.h"
// See Note [ Why still include the fluid headers? ]
#include "paddle/fluid/operators/eigen/eigen_function.h"
namespace pten {
namespace eigen {
template <typename DevCtx, typename T>
void Mean(const DevCtx& dev_ctx, const DenseTensor& x, DenseTensor* out) {
// TODO(chenweihang): if we design new tensor, we should support
// the low-level calc functor use new tensor as input,
// which may be a big project!
out->mutable_data<T>();
auto eigen_x = pten::EigenVector<T>::Flatten(x);
auto eigen_out = pten::EigenScalar<T>::From(*out);
auto& dev = *dev_ctx.eigen_device();
eigen_out.device(dev) = eigen_x.mean();
}
} // namespace eigen
} // namespace pten
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/pten/api/lib/utils/allocator.h"
#include "paddle/pten/core/dense_tensor.h"
#include "paddle/pten/kernels/functions/eigen/common.h"
#include "paddle/pten/kernels/math/transpose.h"
// See Note [ Why still include the fluid headers? ]
#include "paddle/fluid/operators/eigen/eigen_function.h"
namespace pten {
namespace eigen {
template <typename DeviceContext,
typename T,
size_t D,
size_t R_D,
typename Functor>
void ReduceFunctor(const DeviceContext& context,
const pten::DenseTensor& input,
pten::DenseTensor* output,
const std::vector<int64_t>& dims,
bool keep_dim) {
auto x = EigenTensor<T, D>::From(input);
auto x_rank = static_cast<int>(x.dimensions().size());
auto reduce_dim = Eigen::array<int, R_D>();
std::vector<int64_t> dims_ref = dims;
for (size_t i = 0; i < dims_ref.size(); ++i) {
if (dims_ref[i] < 0) dims_ref[i] = x_rank + dims_ref[i];
reduce_dim[i] = dims_ref[i];
}
// construct the squeezed output tensor
DDim out_dims = output->dims();
if (keep_dim && x_rank > 1) {
const int kDelFlag = -2;
auto dims_vector = paddle::framework::vectorize(out_dims);
for (size_t i = 0; i < dims_ref.size(); ++i) {
dims_vector[dims_ref[i]] = kDelFlag;
}
dims_vector.erase(remove(dims_vector.begin(), dims_vector.end(), kDelFlag),
dims_vector.end());
out_dims = paddle::framework::make_ddim(dims_vector);
}
auto& place = *context.eigen_device();
Functor functor;
if (D == 1) {
auto out = EigenScalar<T>::From(*output);
functor(place, &x, &out, reduce_dim);
} else {
auto out = EigenTensor<T, (D - R_D)>::From(*output, out_dims);
functor(place, &x, &out, reduce_dim);
}
}
#define HANDLE_REDUCE_DIM(NDIM, RDIM) \
if (ndim == NDIM && rdim == RDIM) { \
ReduceFunctor<DeviceContext, OutT, NDIM, RDIM, Functor>( \
dev_ctx, input, output, dims, keep_dim); \
}
//////////////// HandleLargeDim
inline void GetShuffledDim(const DDim& src_dims,
DDim* dst_dims,
const std::vector<int64_t>& reduced_dims,
std::vector<int64_t>* perm_axis) {
// check if it's a reduced dim
std::vector<bool> src_dims_check(src_dims.size(), false);
size_t src_size = src_dims.size();
size_t reduce_size = reduced_dims.size();
for (size_t i = 0; i < reduce_size; ++i) {
dst_dims->at(src_size - reduce_size + i) = src_dims[reduced_dims[i]];
(*perm_axis)[src_size - reduce_size + i] = reduced_dims[i];
src_dims_check[reduced_dims[i]] = true;
}
size_t offset = 0;
for (size_t i = 0; i < src_dims_check.size(); ++i) {
bool is_reduced = src_dims_check[i];
if (!is_reduced) {
(*perm_axis)[offset] = i;
dst_dims->at(offset++) = src_dims[i];
}
}
}
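// Illustrative example (added comment, not in the original source): for
// src_dims = [2, 3, 4, 5] and reduced_dims = {1, 3}, this produces
// dst_dims = [2, 4, 3, 5] and perm_axis = [0, 2, 1, 3], i.e. the reduced
// dimensions are moved to the end.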
template <typename DeviceContext, typename OutT>
void GetShuffledInput(const DeviceContext& dev_ctx,
const pten::DenseTensor& input,
pten::DenseTensor* shuffled_input,
const std::vector<int64_t>& dims) {
DDim shuffled_dims(input.dims());
std::vector<int64_t> perm_axis(input.dims().size());
GetShuffledDim(input.dims(), &shuffled_dims, dims, &perm_axis);
shuffled_input->Resize(shuffled_dims);
shuffled_input->mutable_data<OutT>();
pten::math::TransposeNormal<DeviceContext, OutT> trans;
trans(dev_ctx, input, shuffled_input, perm_axis);
}
template <typename DeviceContext, typename OutT, typename Functor>
void HandleLargeDim(const DeviceContext& dev_ctx,
const pten::DenseTensor& input,
pten::DenseTensor* output,
const std::vector<int64_t>& dims,
bool keep_dim) {
// shuffle the reduced dim to the end
const auto alloc =
std::make_shared<paddle::experimental::DefaultAllocator>(input.place());
pten::DenseTensor shuffled_input = pten::DenseTensor(alloc, input.meta());
GetShuffledInput<DeviceContext, OutT>(dev_ctx, input, &shuffled_input, dims);
// transpose to 2D tensor whose shape is {unreduced, reduced}.
const int64_t unreduced = output->numel();
const int64_t reduced = shuffled_input.numel() / unreduced;
shuffled_input.Resize({unreduced, reduced});
DDim output_dim = output->dims();
output->Resize({unreduced});
ReduceFunctor<DeviceContext, OutT, 2, 1, Functor>(
dev_ctx, shuffled_input, output, {1}, keep_dim);
output->Resize(output_dim);
}
////////////// ReduceKernel
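// ReduceKernelImpl is the common entry point: for reduce_all it flattens the
// input and reduces everything to a scalar; otherwise it dispatches on the
// (rank, reduced-dim count) pair and falls back to HandleLargeDim for rank > 6.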
template <typename DeviceContext, typename T, typename OutT, typename Functor>
void ReduceKernelImpl(const DeviceContext& dev_ctx,
const pten::DenseTensor& input,
pten::DenseTensor* output,
const std::vector<int64_t>& dims,
bool keep_dim,
bool reduce_all) {
output->mutable_data<OutT>();
if (reduce_all) {
    // Flatten the input into a 1-D tensor and reduce all of its elements
auto x = EigenVector<OutT>::Flatten(input);
auto out = EigenScalar<OutT>::From(*output);
auto& dev = *dev_ctx.eigen_device();
auto reduce_dim = Eigen::array<int, 1>({{0}});
Functor functor;
functor(dev, &x, &out, reduce_dim);
} else {
int ndim = input.dims().size();
int rdim = dims.size();
if (ndim > 6) {
HandleLargeDim<DeviceContext, OutT, Functor>(
dev_ctx, input, output, dims, keep_dim);
} else {
HANDLE_REDUCE_DIM(6, 5);
HANDLE_REDUCE_DIM(6, 4);
HANDLE_REDUCE_DIM(6, 3);
HANDLE_REDUCE_DIM(6, 2);
HANDLE_REDUCE_DIM(6, 1);
HANDLE_REDUCE_DIM(5, 4);
HANDLE_REDUCE_DIM(5, 3);
HANDLE_REDUCE_DIM(5, 2);
HANDLE_REDUCE_DIM(5, 1);
HANDLE_REDUCE_DIM(4, 3);
HANDLE_REDUCE_DIM(4, 2);
HANDLE_REDUCE_DIM(4, 1);
HANDLE_REDUCE_DIM(3, 2);
HANDLE_REDUCE_DIM(3, 1);
HANDLE_REDUCE_DIM(2, 1);
HANDLE_REDUCE_DIM(1, 1);
}
}
}
//////// Sum Functor ///////
struct SumFunctor {
template <typename DeviceContext, typename X, typename Y, typename Dim>
void operator()(const DeviceContext& place, X* x, Y* y, const Dim& dim) {
y->device(place) = x->sum(dim);
}
};
//////// Mean Functor ///////
struct MeanFunctor {
template <typename DeviceContext, typename X, typename Y, typename Dim>
void operator()(const DeviceContext& place, X* x, Y* y, const Dim& dim) {
y->device(place) = x->mean(dim);
}
};
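// Illustrative sketch (not part of this file): reducing a rank-2 tensor over
// both axes with SumFunctor could look like the call below, where CPUContext
// stands for paddle::platform::CPUDeviceContext and the tensor names are
// assumptions for the example.
//
//   pten::eigen::ReduceKernelImpl<CPUContext, float, float, SumFunctor>(
//       dev_ctx, x, &out, /*dims=*/{0, 1}, /*keep_dim=*/false,
//       /*reduce_all=*/true);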
} // namespace eigen
} // namespace pten
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/platform/transform.h"
#include "paddle/pten/api/ext/dispatch.h"
#include "paddle/pten/core/dense_tensor.h"
#include "paddle/pten/kernels/functions/eigen/reduce.h"
#include "paddle/pten/kernels/functions/math/cast_func.h"
namespace pten {
namespace general {
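// Reduce is the device-independent driver: it promotes `reduce_all` when
// `dims` covers every input axis, optionally casts the input to `out_dtype`
// first, and then visits the output dtype to run ReduceKernelImpl with the
// requested Functor.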
template <typename DeviceContext, typename T, typename Functor>
void Reduce(const DeviceContext& dev_ctx,
const DenseTensor& x,
bool reduce_all,
const std::vector<int64_t>& dims,
bool keep_dim,
DataType out_dtype,
DenseTensor* out) {
  // If `dims` covers every axis of the input, force reduce_all to true
const auto& input_dim_size = x.dims().size();
std::set<int> dims_set(dims.begin(), dims.end());
bool full_dim = true;
for (auto i = 0; i < input_dim_size; ++i) {
if (dims_set.find(i) == dims_set.end()) {
full_dim = false;
break;
}
}
reduce_all = (reduce_all || full_dim);
// no need to cast dtype
if (out_dtype == pten::DataType::UNDEFINED || out_dtype == x.dtype()) {
if (out_dtype == pten::DataType::UNDEFINED) {
out_dtype = x.dtype();
}
    // perform the reduction with the given Functor
PD_VISIT_ALL_TYPES(
out_dtype, "ReduceKernelImpl", ([&] {
pten::eigen::ReduceKernelImpl<DeviceContext, T, data_t, Functor>(
dev_ctx, x, out, dims, keep_dim, reduce_all);
}));
} else {
const auto alloc =
std::make_shared<paddle::experimental::DefaultAllocator>(x.place());
pten::DenseTensor tmp_tensor = pten::DenseTensor(
alloc, pten::DenseTensorMeta(out_dtype, x.dims(), x.layout()));
// cast x tensor to out_dtype first
PD_VISIT_ALL_TYPES(out_dtype, "CastKernelImpl", ([&] {
math::CastKernelImpl<DeviceContext, T, data_t>(
dev_ctx, x, &tmp_tensor);
}));
    // perform the reduction with the given Functor
PD_VISIT_ALL_TYPES(
out_dtype, "ReduceKernelImpl", ([&] {
pten::eigen::ReduceKernelImpl<DeviceContext, T, data_t, Functor>(
dev_ctx, tmp_tensor, out, dims, keep_dim, reduce_all);
}));
}
}
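// Illustrative sketch (not part of this file): a CPU mean kernel could
// forward to this helper roughly as follows, where CPUContext aliases
// paddle::platform::CPUDeviceContext and the surrounding kernel signature is
// an assumption for the example.
//
//   general::Reduce<CPUContext, float, pten::eigen::MeanFunctor>(
//       dev_ctx, x, /*reduce_all=*/false, /*dims=*/{0},
//       /*keep_dim=*/false, pten::DataType::UNDEFINED, &out);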
} // namespace general
} // namespace pten
add_subdirectory(cpu)
if(WITH_GPU OR WITH_ROCM)
add_subdirectory(cuda)
endif()
cc_library(pten_transpose_cpu SRCS transpose.cc DEPS dense_tensor)
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/pten/kernels/math/transpose.h"
#include "paddle/fluid/framework/ddim.h"
#include "paddle/pten/core/dense_tensor.h"
// See Note [ Why still include the fluid headers? ]
#include "paddle/fluid/platform/bfloat16.h"
#include "paddle/fluid/platform/complex.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/float16.h"
namespace pten {
namespace math {
using CPUContext = paddle::platform::CPUDeviceContext;
template <typename T>
struct TransposeNormal<CPUContext, T> {
  // used when the tensor rank is >= 7
void operator()(const CPUContext& dev_ctx,
const pten::DenseTensor& in,
pten::DenseTensor* out,
const std::vector<int64_t>& axis) {
const int rank = axis.size();
auto in_stride = paddle::framework::stride(in.dims());
auto out_stride = paddle::framework::stride(out->dims());
const T* in_ptr = in.data<T>();
T* out_ptr = out->mutable_data<T>();
auto transpose_helper = [&](int64_t beg, int64_t end) {
for (int64_t out_idx = beg; out_idx < end; ++out_idx) {
int64_t in_idx = 0;
int64_t tmp_idx = out_idx;
// calculate the input index
for (int i = 0; i < rank; ++i) {
const int64_t coordinate = tmp_idx / out_stride[i];
tmp_idx -= coordinate * out_stride[i];
in_idx += coordinate * in_stride[axis[i]];
}
out_ptr[out_idx] = in_ptr[in_idx];
}
};
transpose_helper(0, out->numel());
}
};
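// Illustrative sketch (not part of this file): permuting a rank-2 tensor with
// this functor; the tensor setup is assumed to be done elsewhere.
//
//   pten::math::TransposeNormal<CPUContext, float> trans;
//   trans(dev_ctx, in /*shape {2, 3}*/, &out /*shape {3, 2}*/, {1, 0});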
// explicitly instantiate TransposeNormal for the supported CPU types
#define DEFINE_CPU_TRANS_NORMAL(TYPE) \
template struct TransposeNormal<CPUContext, TYPE>
DEFINE_CPU_TRANS_NORMAL(bool);
DEFINE_CPU_TRANS_NORMAL(int8_t);
DEFINE_CPU_TRANS_NORMAL(uint8_t);
DEFINE_CPU_TRANS_NORMAL(int16_t);
DEFINE_CPU_TRANS_NORMAL(uint16_t);
DEFINE_CPU_TRANS_NORMAL(int32_t);
DEFINE_CPU_TRANS_NORMAL(uint32_t);
DEFINE_CPU_TRANS_NORMAL(int64_t);
DEFINE_CPU_TRANS_NORMAL(uint64_t);
DEFINE_CPU_TRANS_NORMAL(float);
DEFINE_CPU_TRANS_NORMAL(double);
DEFINE_CPU_TRANS_NORMAL(paddle::platform::float16);
DEFINE_CPU_TRANS_NORMAL(paddle::platform::bfloat16);
DEFINE_CPU_TRANS_NORMAL(paddle::platform::complex<float>);
DEFINE_CPU_TRANS_NORMAL(paddle::platform::complex<double>);
} // namespace math
} // namespace pten
if(WITH_GPU)
nv_library(pten_transpose_cuda SRCS transpose.cu DEPS dense_tensor malloc)
elseif(WITH_ROCM)
hip_library(pten_transpose_cuda SRCS transpose.cu DEPS dense_tensor malloc)
endif()
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/ddim.h"
#include "paddle/fluid/memory/memcpy.h"
#include "paddle/pten/core/dense_tensor.h"
#include "paddle/pten/kernels/functions/math/cast_func.h"
#include "paddle/pten/kernels/math/transpose.h"
// See Note [ Why still include the fluid headers? ]
#include "paddle/fluid/platform/bfloat16.h"
#include "paddle/fluid/platform/complex.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/float16.h"
namespace pten {
namespace math {
using CUDAContext = paddle::platform::CUDADeviceContext;
#define REINTERPRET(T, DST_PTR, SRC_PTR) \
T* DST_PTR = reinterpret_cast<T*>(SRC_PTR)
template <typename T>
__global__ void TransposeNormalKernel(const T* in_ptr,
T* out_ptr,
int64_t element,
const int64_t* in_stride_ptr,
const int64_t* out_stride_ptr,
const int64_t* axis_ptr,
int rank) {
CUDA_KERNEL_LOOP(out_idx, element) {
int64_t in_idx = 0;
int64_t tmp_idx = out_idx;
for (int i = 0; i < rank; ++i) {
const int64_t coordinate = tmp_idx / out_stride_ptr[i];
tmp_idx -= coordinate * out_stride_ptr[i];
in_idx += coordinate * in_stride_ptr[axis_ptr[i]];
}
out_ptr[out_idx] = in_ptr[in_idx];
}
}
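// The CUDA specialization packs in_stride, out_stride and axis into one host
// buffer, copies it to the device in a single memcpy, and then launches the
// grid-stride kernel above to remap every output element to its input index.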
template <typename T>
struct TransposeNormal<CUDAContext, T> {
  // used when the tensor rank is >= 7
void operator()(const CUDAContext& dev_ctx,
const pten::DenseTensor& in,
pten::DenseTensor* out,
const std::vector<int64_t>& axis) {
const int rank = axis.size();
auto in_stride = paddle::framework::stride(in.dims());
auto out_stride = paddle::framework::stride(out->dims());
auto* in_ptr = in.data<T>();
auto* out_ptr = out->mutable_data<T>();
    // copy in_stride, out_stride and axis to the GPU device
const paddle::platform::CUDAPlace& cuda_place =
BOOST_GET_CONST(paddle::platform::CUDAPlace, dev_ctx.GetPlace());
paddle::platform::CPUPlace cpu_place = paddle::platform::CPUPlace();
size_t size = 3 * rank * sizeof(int64_t);
auto cpu_buf_holder = paddle::memory::AllocShared(cpu_place, size);
auto cuda_buf_holder = paddle::memory::AllocShared(cuda_place, size);
REINTERPRET(int64_t, cpu_buf, cpu_buf_holder->ptr());
REINTERPRET(int64_t, cuda_buf, cuda_buf_holder->ptr());
for (int i = 0; i < rank; ++i) {
cpu_buf[i] = in_stride[i];
cpu_buf[rank + i] = out_stride[i];
cpu_buf[2 * rank + i] = axis[i];
}
paddle::memory::Copy(
cuda_place, cuda_buf, cpu_place, cpu_buf, size, dev_ctx.stream());
REINTERPRET(const int64_t, in_stride_ptr, cuda_buf);
REINTERPRET(const int64_t, out_stride_ptr, cuda_buf + rank);
REINTERPRET(const int64_t, axis_ptr, cuda_buf + 2 * rank);
const int MAX_BLOCK_DIM = dev_ctx.GetMaxThreadsPerBlock();
const int MAX_GRID_DIM =
dev_ctx.GetMaxPhysicalThreadCount() / MAX_BLOCK_DIM;
int64_t elements = in.numel();
int block_size = (elements >= MAX_BLOCK_DIM)
? MAX_BLOCK_DIM
: (1 << static_cast<int>(std::log2(elements)));
int grid_size = elements / block_size;
grid_size = (grid_size >= MAX_GRID_DIM) ? MAX_GRID_DIM : grid_size;
TransposeNormalKernel<T><<<grid_size, block_size, 0, dev_ctx.stream()>>>(
in_ptr,
out_ptr,
elements,
in_stride_ptr,
out_stride_ptr,
axis_ptr,
rank);
}
};
// explicitly instantiate TransposeNormal for the supported GPU types
#define DEFINE_GPU_TRANS_NORMAL(TYPE) \
template struct TransposeNormal<CUDAContext, TYPE>
DEFINE_GPU_TRANS_NORMAL(bool);
DEFINE_GPU_TRANS_NORMAL(int8_t);
DEFINE_GPU_TRANS_NORMAL(uint8_t);
DEFINE_GPU_TRANS_NORMAL(int16_t);
DEFINE_GPU_TRANS_NORMAL(uint16_t);
DEFINE_GPU_TRANS_NORMAL(int32_t);
DEFINE_GPU_TRANS_NORMAL(uint32_t);
DEFINE_GPU_TRANS_NORMAL(int64_t);
DEFINE_GPU_TRANS_NORMAL(uint64_t);
DEFINE_GPU_TRANS_NORMAL(float);
DEFINE_GPU_TRANS_NORMAL(double);
DEFINE_GPU_TRANS_NORMAL(paddle::platform::float16);
DEFINE_GPU_TRANS_NORMAL(paddle::platform::bfloat16);
DEFINE_GPU_TRANS_NORMAL(paddle::platform::complex<float>);
DEFINE_GPU_TRANS_NORMAL(paddle::platform::complex<double>);
} // namespace math
} // namespace pten
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/framework/ddim.h"
#include "paddle/pten/core/dense_tensor.h"
namespace pten {
namespace math {
template <typename DeviceContext, typename T>
struct TransposeNormal {
  // used when the tensor rank is >= 7
void operator()(const DeviceContext& dev_ctx,
const pten::DenseTensor& in,
pten::DenseTensor* out,
const std::vector<int64_t>& axis);
};
} // namespace math
} // namespace pten
......@@ -19,3 +19,4 @@ cc_test(test_cast_api SRCS test_cast_api.cc DEPS pten_tensor pten_api pten_api_u
cc_test(test_reshape_api SRCS test_reshape_api.cc DEPS pten_tensor pten_api pten_api_utils)
cc_test(test_to_api SRCS test_to_api.cc DEPS pten_tensor pten_api pten_api_utils)
cc_test(test_slice_api SRCS test_slice_api.cc DEPS pten_tensor pten_api pten_api_utils)
cc_test(test_sum_api SRCS test_sum_api.cc DEPS pten_tensor pten_api pten_api_utils)
......@@ -46,9 +46,10 @@ TEST(API, mean) {
}
paddle::experimental::Tensor x(dense_x);
std::vector<int64_t> axis = {0, 1};
// 2. test API
auto out = paddle::experimental::mean(x);
auto out = paddle::experimental::mean(x, axis, false);
// 3. check result
ASSERT_EQ(out.dims().size(), 1);
......
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <gtest/gtest.h>
#include <memory>
#include "paddle/pten/api/include/math.h"
#include "paddle/pten/api/lib/utils/allocator.h"
#include "paddle/pten/core/dense_tensor.h"
#include "paddle/pten/core/kernel_registry.h"
namespace paddle {
namespace tests {
namespace framework = paddle::framework;
using DDim = paddle::framework::DDim;
// TODO(chenweihang): Remove this test after the API is used in the dygraph
TEST(API, sum) {
// 1. create tensor
const auto alloc = std::make_shared<paddle::experimental::DefaultAllocator>(
paddle::platform::CPUPlace());
auto dense_x = std::make_shared<pten::DenseTensor>(
alloc,
pten::DenseTensorMeta(pten::DataType::FLOAT32,
framework::make_ddim({3, 4}),
pten::DataLayout::NCHW));
auto* dense_x_data = dense_x->mutable_data<float>();
float sum = 0.0;
for (size_t i = 0; i < 12; ++i) {
dense_x_data[i] = i * 1.0;
sum += i * 1.0;
}
paddle::experimental::Tensor x(dense_x);
std::vector<int64_t> axis = {0, 1};
// 2. test API
auto out = paddle::experimental::sum(x, axis, DataType::UNDEFINED, false);
// 3. check result
ASSERT_EQ(out.dims().size(), 1);
ASSERT_EQ(out.dims()[0], 1);
ASSERT_EQ(out.numel(), 1);
ASSERT_EQ(out.is_cpu(), true);
ASSERT_EQ(out.type(), pten::DataType::FLOAT32);
ASSERT_EQ(out.layout(), pten::DataLayout::NCHW);
ASSERT_EQ(out.initialized(), true);
auto expect_result = sum;
auto dense_out = std::dynamic_pointer_cast<pten::DenseTensor>(out.impl());
auto actual_result = dense_out->data<float>()[0];
ASSERT_NEAR(expect_result, actual_result, 1e-6f);
}
} // namespace tests
} // namespace paddle
......@@ -7,3 +7,4 @@ cc_test(test_scale_dev_api SRCS test_scale_dev_api.cc DEPS pten pten_api_utils)
cc_test(test_cast_dev_api SRCS test_cast_dev_api.cc DEPS pten pten_api_utils)
cc_test(test_elementwise_dev_api SRCS test_elementwise_dev_api.cc DEPS pten pten_api_utils)
cc_test(test_reshape_dev_api SRCS test_reshape_dev_api.cc DEPS pten pten_api_utils)
cc_test(test_sum_dev_api SRCS test_sum_dev_api.cc DEPS pten pten_api_utils)
......@@ -45,9 +45,14 @@ TEST(DEV_API, mean) {
paddle::platform::DeviceContextPool& pool =
paddle::platform::DeviceContextPool::Instance();
auto* dev_ctx = pool.Get(paddle::platform::CPUPlace());
std::vector<int64_t> dims = {0, 1};
// 2. test API
auto out = pten::Mean<float>(
*(static_cast<paddle::platform::CPUDeviceContext*>(dev_ctx)), dense_x);
*(static_cast<paddle::platform::CPUDeviceContext*>(dev_ctx)),
dense_x,
dims,
false);
// 3. check result
ASSERT_EQ(out.dims().size(), 1);
......
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <gtest/gtest.h>
#include <memory>
#include "paddle/pten/include/math.h"
#include "paddle/pten/api/lib/utils/allocator.h"
#include "paddle/pten/core/dense_tensor.h"
#include "paddle/pten/core/kernel_registry.h"
namespace pten {
namespace tests {
namespace framework = paddle::framework;
using DDim = paddle::framework::DDim;
TEST(DEV_API, sum) {
// 1. create tensor
const auto alloc = std::make_shared<paddle::experimental::DefaultAllocator>(
paddle::platform::CPUPlace());
pten::DenseTensor dense_x(alloc,
pten::DenseTensorMeta(pten::DataType::FLOAT32,
framework::make_ddim({3, 4}),
pten::DataLayout::NCHW));
auto* dense_x_data = dense_x.mutable_data<float>();
float sum = 0.0;
for (size_t i = 0; i < 12; ++i) {
dense_x_data[i] = i * 1.0;
sum += i * 1.0;
}
paddle::platform::DeviceContextPool& pool =
paddle::platform::DeviceContextPool::Instance();
auto* dev_ctx = pool.Get(paddle::platform::CPUPlace());
std::vector<int64_t> axis = {0, 1};
// 2. test API
auto out = pten::Sum<float>(
*(static_cast<paddle::platform::CPUDeviceContext*>(dev_ctx)),
dense_x,
axis,
pten::DataType::FLOAT32,
false);
// 3. check result
ASSERT_EQ(out.dims().size(), 1);
ASSERT_EQ(out.numel(), 1);
ASSERT_EQ(out.meta().dtype, pten::DataType::FLOAT32);
ASSERT_EQ(out.meta().layout, pten::DataLayout::NCHW);
auto expect_result = sum;
auto actual_result = out.data<float>()[0];
ASSERT_NEAR(expect_result, actual_result, 1e-6f);
}
} // namespace tests
} // namespace pten