Unverified commit 40d4a295, authored by Yiqun Liu, committed by GitHub

Revert "Implement FunctionTraits to support two kinds of elementwise functor...

Revert "Implement FunctionTraits to support two kinds of elementwise functor and remove some old codes for broadcast. (#35487)" (#35686)
Parent 5f31737b
......@@ -24,15 +24,15 @@ struct CudaAbsFunctor;
template <typename T>
struct CudaAbsFunctor<T, math::Complex<T, math::Real<T>>> {
__device__ __forceinline__ math::Real<T> operator()(const T& x) const {
return abs(x);
__device__ __forceinline__ math::Real<T> operator()(const T* args) const {
return abs(args[0]);
}
};
template <typename T>
struct CudaAbsFunctor<T, math::NoComplex<T, math::Real<T>>> {
__device__ __forceinline__ T operator()(const T& x) const {
return std::abs(x);
__device__ __forceinline__ T operator()(const T* args) const {
return std::abs(args[0]);
}
};
......
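For context on the interface change visible in this hunk: PR #35487 let elementwise functors take one value per argument (e.g. `const T& x`), with FunctionTraits detecting which style a functor uses, and this revert restores the convention of a single pointer to a packed argument array. A minimal standalone sketch contrasting the two styles (hypothetical names, `__device__`/HOSTDEVICE qualifiers elided so it compiles as plain C++; not part of the diff):

#include <cmath>
#include <cstdio>

// Style introduced by #35487 and removed by this revert: one value per argument.
template <typename T>
struct ValueStyleAbs {
  T operator()(const T& x) const { return std::abs(x); }
};

// Style restored by this revert: a pointer to a packed argument array,
// args[0] for unary ops, args[0] and args[1] for binary ops.
template <typename T>
struct PointerStyleAbs {
  T operator()(const T* args) const { return std::abs(args[0]); }
};

int main() {
  float packed[1] = {-2.5f};
  std::printf("%f %f\n", ValueStyleAbs<float>()(-2.5f),
              PointerStyleAbs<float>()(packed));
  return 0;
}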
......@@ -18,46 +18,60 @@ limitations under the License. */
namespace paddle {
namespace operators {
template <typename Functor>
class BinaryBitwiseOpKernel<platform::CUDADeviceContext, Functor>
: public framework::OpKernel<typename Functor::ELEM_TYPE> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
using T = typename Functor::ELEM_TYPE;
#define BITWISE_BINARY_FUNCTOR(func, expr, bool_expr) \
template <typename T> \
struct Bitwise##func##CUDAFunctor { \
using ELEM_TYPE = T; \
HOSTDEVICE T operator()(const T* args) const { \
return args[0] expr args[1]; \
} \
}; \
\
template <> \
struct Bitwise##func##CUDAFunctor<bool> { \
using ELEM_TYPE = bool; \
HOSTDEVICE bool operator()(const bool* args) const { \
return args[0] bool_expr args[1]; \
} \
};
auto* x = ctx.Input<framework::Tensor>("X");
auto* y = ctx.Input<framework::Tensor>("Y");
auto* out = ctx.Output<framework::Tensor>("Out");
out->mutable_data<T>(ctx.GetPlace());
BITWISE_BINARY_FUNCTOR(And, &, &&)
BITWISE_BINARY_FUNCTOR(Or, |, ||)
BITWISE_BINARY_FUNCTOR(Xor, ^, !=)
#undef BITWISE_BINARY_FUNCTOR
auto functor = Functor();
std::vector<const framework::Tensor*> ins = {x, y};
std::vector<framework::Tensor*> outs = {out};
const auto& cuda_ctx =
ctx.template device_context<platform::CUDADeviceContext>();
LaunchElementwiseCudaKernel<ElementwiseType::kBinary, T, T>(
cuda_ctx, ins, &outs, -1, functor);
}
template <typename T>
struct BitwiseNotCUDAFunctor {
using ELEM_TYPE = T;
HOSTDEVICE T operator()(const T* args) const { return ~args[0]; }
};
template <>
struct BitwiseNotCUDAFunctor<bool> {
using ELEM_TYPE = bool;
HOSTDEVICE bool operator()(const bool* args) const { return !args[0]; }
};
template <typename Functor>
class UnaryBitwiseOpKernel<platform::CUDADeviceContext, Functor>
class BinaryBitwiseOpKernel<platform::CUDADeviceContext, Functor>
: public framework::OpKernel<typename Functor::ELEM_TYPE> {
public:
using T = typename Functor::ELEM_TYPE;
void Compute(const framework::ExecutionContext& ctx) const override {
using T = typename Functor::ELEM_TYPE;
auto* x = ctx.Input<framework::Tensor>("X");
auto* out = ctx.Output<framework::Tensor>("Out");
out->mutable_data<T>(ctx.GetPlace());
auto functor = Functor();
std::vector<const framework::Tensor*> ins = {x};
std::vector<framework::Tensor*> outs = {out};
std::vector<const framework::Tensor*> ins;
std::vector<framework::Tensor*> outs;
const auto& cuda_ctx =
ctx.template device_context<platform::CUDADeviceContext>();
LaunchSameDimsElementwiseCudaKernel<ElementwiseType::kUnary, T, T>(
cuda_ctx, ins, &outs, functor);
int axis = PackTensorsIntoVector<T>(ctx, &ins, &outs);
if (ins.size() == 1) {
LaunchElementwiseCudaKernel<ElementwiseType::kUnary, T, T>(
cuda_ctx, ins, &outs, axis, functor);
} else {
LaunchElementwiseCudaKernel<ElementwiseType::kBinary, T, T>(
cuda_ctx, ins, &outs, axis, functor);
}
}
};
......@@ -67,7 +81,7 @@ class UnaryBitwiseOpKernel<platform::CUDADeviceContext, Functor>
namespace ops = ::paddle::operators;
namespace plat = ::paddle::platform;
REGISTER_BINARY_BITWISE_KERNEL(bitwise_and, CUDA, ops::BitwiseAndFunctor);
REGISTER_BINARY_BITWISE_KERNEL(bitwise_or, CUDA, ops::BitwiseOrFunctor);
REGISTER_BINARY_BITWISE_KERNEL(bitwise_xor, CUDA, ops::BitwiseXorFunctor);
REGISTER_UNARY_BITWISE_KERNEL(bitwise_not, CUDA, ops::BitwiseNotFunctor);
REGISTER_BINARY_BITWISE_KERNEL(bitwise_and, CUDA, ops::BitwiseAndCUDAFunctor);
REGISTER_BINARY_BITWISE_KERNEL(bitwise_or, CUDA, ops::BitwiseOrCUDAFunctor);
REGISTER_BINARY_BITWISE_KERNEL(bitwise_xor, CUDA, ops::BitwiseXorCUDAFunctor);
REGISTER_BINARY_BITWISE_KERNEL(bitwise_not, CUDA, ops::BitwiseNotCUDAFunctor);
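For readability, hand-expanding BITWISE_BINARY_FUNCTOR(And, &, &&) from the hunk above yields roughly the following; this is a standalone sketch (HOSTDEVICE elided), not part of the diff:

#include <cstdio>

template <typename T>
struct BitwiseAndCUDAFunctor {
  using ELEM_TYPE = T;
  T operator()(const T* args) const { return args[0] & args[1]; }
};

// bool uses logical && so the result stays a strict true/false value.
template <>
struct BitwiseAndCUDAFunctor<bool> {
  using ELEM_TYPE = bool;
  bool operator()(const bool* args) const { return args[0] && args[1]; }
};

int main() {
  int i[2] = {6, 3};
  bool b[2] = {true, false};
  std::printf("%d %d\n", BitwiseAndCUDAFunctor<int>()(i),
              static_cast<int>(BitwiseAndCUDAFunctor<bool>()(b)));
  return 0;
}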
......@@ -17,6 +17,9 @@ limitations under the License. */
#include "paddle/fluid/operators/elementwise/elementwise_op_impl.cu.h"
#include "paddle/fluid/operators/reduce_ops/cub_reduce.h"
namespace ops = paddle::operators;
namespace plat = paddle::platform;
namespace paddle {
namespace operators {
......@@ -35,6 +38,23 @@ struct BitwiseAdd {
}
};
template <typename T, typename Enable = void>
struct CudaEqualReduceFunctor {
using ELEM_TYPE = T;
HOSTDEVICE bool operator()(const T args[]) const {
return (args[0] == args[1]);
}
};
template <typename T>
struct CudaEqualReduceFunctor<
T, typename std::enable_if<std::is_floating_point<T>::value>::type> {
using ELEM_TYPE = T;
HOSTDEVICE bool operator()(const T args[]) const {
return fabs(static_cast<double>(args[0] - args[1])) < 1e-8;
}
};
template <typename DeviceContext, typename Functor>
class CompareReduceOpKernel
: public framework::OpKernel<typename Functor::ELEM_TYPE> {
......@@ -77,9 +97,6 @@ class CompareReduceOpKernel
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
namespace plat = paddle::platform;
#define REGISTER_COMPARE_REDUCE_CUDA_KERNEL(op_type, functor) \
REGISTER_OP_CUDA_KERNEL( \
op_type, \
......@@ -92,5 +109,5 @@ namespace plat = paddle::platform;
ops::CompareReduceOpKernel<plat::CUDADeviceContext, \
ops::functor<double>>);
REGISTER_COMPARE_REDUCE_CUDA_KERNEL(equal_all, EqualReduceFunctor)
REGISTER_COMPARE_REDUCE_CUDA_KERNEL(equal_all, CudaEqualReduceFunctor)
#undef REGISTER_COMPARE_REDUCE_CUDA_KERNEL
......@@ -21,11 +21,46 @@ namespace plat = paddle::platform;
namespace paddle {
namespace operators {
#define DEFINE_CMP_BINARY_FUNCTOR_WITH_PONTER_INPUT(func, op) \
template <typename T, typename Enable = void> \
struct func { \
using ELEMENT_TYPE = T; \
inline HOSTDEVICE bool operator()(const T* args) const { \
return args[0] op args[1]; \
} \
};
DEFINE_CMP_BINARY_FUNCTOR_WITH_PONTER_INPUT(CudaLessThanFunctor, <)
DEFINE_CMP_BINARY_FUNCTOR_WITH_PONTER_INPUT(CudaLessEqualFunctor, <=)
DEFINE_CMP_BINARY_FUNCTOR_WITH_PONTER_INPUT(CudaGreaterThanFunctor, >)
DEFINE_CMP_BINARY_FUNCTOR_WITH_PONTER_INPUT(CudaGreaterEqualFunctor, >=)
DEFINE_CMP_BINARY_FUNCTOR_WITH_PONTER_INPUT(CudaEqualFunctor, ==)
DEFINE_CMP_BINARY_FUNCTOR_WITH_PONTER_INPUT(CudaNotEqualFunctor, !=)
#undef DEFINE_CMP_BINARY_FUNCTOR_WITH_PONTER_INPUT
template <typename T>
struct CudaEqualFunctor<
T, typename std::enable_if<std::is_floating_point<T>::value>::type> {
using ELEMENT_TYPE = T;
HOSTDEVICE bool operator()(const T* args) const {
return fabs(static_cast<double>(args[0] - args[1])) < 1e-8;
}
};
template <typename T>
struct CudaNotEqualFunctor<
T, typename std::enable_if<std::is_floating_point<T>::value>::type> {
using ELEMENT_TYPE = T;
HOSTDEVICE bool operator()(const T* args) const {
return fabs(static_cast<double>(args[0] - args[1])) > 1e-8;
}
};
template <typename Functor, typename InverseFunctor>
class CompareOpKernel<platform::CUDADeviceContext, Functor, InverseFunctor>
: public framework::OpKernel<typename Functor::ELEM_TYPE> {
: public framework::OpKernel<typename Functor::ELEMENT_TYPE> {
public:
using InT = typename Functor::ELEM_TYPE;
using InT = typename Functor::ELEMENT_TYPE;
using OutT = bool;
void Compute(const framework::ExecutionContext& ctx) const override {
auto functor = Functor();
......@@ -52,10 +87,10 @@ class CompareOpKernel<platform::CUDADeviceContext, Functor, InverseFunctor>
ops::CompareOpKernel<plat::CUDADeviceContext, ops::func<float>, void>, \
ops::CompareOpKernel<plat::CUDADeviceContext, ops::func<double>, void>);
REGISTER_CUDA_COMPARE_KERNEL(equal, EqualFunctor)
REGISTER_CUDA_COMPARE_KERNEL(not_equal, NotEqualFunctor)
REGISTER_CUDA_COMPARE_KERNEL(less_than, LessThanFunctor)
REGISTER_CUDA_COMPARE_KERNEL(less_equal, LessEqualFunctor)
REGISTER_CUDA_COMPARE_KERNEL(greater_than, GreaterThanFunctor)
REGISTER_CUDA_COMPARE_KERNEL(greater_equal, GreaterEqualFunctor)
REGISTER_CUDA_COMPARE_KERNEL(equal, CudaEqualFunctor)
REGISTER_CUDA_COMPARE_KERNEL(not_equal, CudaNotEqualFunctor)
REGISTER_CUDA_COMPARE_KERNEL(less_than, CudaLessThanFunctor)
REGISTER_CUDA_COMPARE_KERNEL(less_equal, CudaLessEqualFunctor)
REGISTER_CUDA_COMPARE_KERNEL(greater_than, CudaGreaterThanFunctor)
REGISTER_CUDA_COMPARE_KERNEL(greater_equal, CudaGreaterEqualFunctor)
#undef REGISTER_CUDA_COMPARE_KERNEL
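Hand-expanding DEFINE_CMP_BINARY_FUNCTOR_WITH_PONTER_INPUT(CudaLessThanFunctor, <) and keeping the floating-point specialization of CudaEqualFunctor gives roughly the code below; a standalone sketch (HOSTDEVICE elided) that also shows why float equality goes through a 1e-8 tolerance rather than operator==:

#include <cmath>
#include <cstdio>
#include <type_traits>

// Expansion of the comparison macro for "less than".
template <typename T, typename Enable = void>
struct CudaLessThanFunctor {
  using ELEMENT_TYPE = T;
  bool operator()(const T* args) const { return args[0] < args[1]; }
};

// Generic equality: exact comparison.
template <typename T, typename Enable = void>
struct CudaEqualFunctor {
  using ELEMENT_TYPE = T;
  bool operator()(const T* args) const { return args[0] == args[1]; }
};

// Floating-point equality with an absolute tolerance, as in the hunk above.
template <typename T>
struct CudaEqualFunctor<
    T, typename std::enable_if<std::is_floating_point<T>::value>::type> {
  using ELEMENT_TYPE = T;
  bool operator()(const T* args) const {
    return std::fabs(static_cast<double>(args[0] - args[1])) < 1e-8;
  }
};

int main() {
  double a[2] = {0.1 + 0.2, 0.3};  // differ by ~5.6e-17 in binary floating point
  std::printf("exact: %d  tolerant: %d  less-than: %d\n",
              static_cast<int>(a[0] == a[1]),
              static_cast<int>(CudaEqualFunctor<double>()(a)),
              static_cast<int>(CudaLessThanFunctor<double>()(a)));
  return 0;
}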
......@@ -11,7 +11,6 @@ distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/elementwise/elementwise_add_op.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h"
#include "paddle/fluid/operators/reduce_ops/reduce_functor_op.h"
......@@ -25,6 +24,21 @@ namespace plat = paddle::platform;
namespace paddle {
namespace operators {
/*
input: an array;
return: the result of the math functor
1. For Unary Op, the length of input array is 1,
e.g. Relu: return args[0] > 0 ? args[0] : 0;
2. For Binary Op, the length of input array is 2,
e.g. Add: return args[0] expr args[1];
*/
template <typename T>
struct CudaAddFunctor {
inline HOSTDEVICE T operator()(const T* args) const {
return args[0] + args[1];
}
};
template <typename T>
class ElementwiseAddKernel<platform::CUDADeviceContext, T>
: public framework::OpKernel<T> {
......@@ -37,7 +51,7 @@ class ElementwiseAddKernel<platform::CUDADeviceContext, T>
int axis = PackTensorsIntoVector<T>(ctx, &ins, &outs);
LaunchElementwiseCudaKernel<ElementwiseType::kBinary, T, T>(
cuda_ctx, ins, &outs, axis, AddFunctor<T>());
cuda_ctx, ins, &outs, axis, CudaAddFunctor<T>());
}
};
......
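The comment in this hunk describes the argument-array contract restored by the revert: a unary functor reads args[0], a binary functor reads args[0] and args[1]. A ternary functor (hypothetical name, matching ElementwiseType::kTernary) would follow the same pattern; standalone sketch, HOSTDEVICE elided:

#include <cstdio>

// Hypothetical ternary functor under the pointer-argument contract:
// fused multiply-add over a packed argument array of length 3.
template <typename T>
struct CudaFMAFunctor {
  T operator()(const T* args) const { return args[0] * args[1] + args[2]; }
};

int main() {
  float packed[3] = {2.0f, 3.0f, 1.0f};
  std::printf("%f\n", CudaFMAFunctor<float>()(packed));  // prints 7.000000
  return 0;
}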
......@@ -24,6 +24,13 @@ namespace plat = paddle::platform;
namespace paddle {
namespace operators {
template <typename T>
struct CudaMulFunctor {
inline HOSTDEVICE T operator()(const T* args) const {
return args[0] * args[1];
}
};
template <typename T>
class ElementwiseMulKernel<platform::CUDADeviceContext, T>
: public framework::OpKernel<T> {
......@@ -37,7 +44,7 @@ class ElementwiseMulKernel<platform::CUDADeviceContext, T>
int axis = PackTensorsIntoVector<T>(ctx, &ins, &outs, &x_for_selectedrows);
LaunchElementwiseCudaKernel<ElementwiseType::kBinary, T, T>(
cuda_ctx, ins, &outs, axis, MulFunctor<T>());
cuda_ctx, ins, &outs, axis, CudaMulFunctor<T>());
}
};
......
......@@ -16,10 +16,11 @@
#include "paddle/fluid/operators/elementwise/elementwise_op_impl.cu.h"
#include "paddle/fluid/operators/kernel_primitives/kernel_primitives.h"
namespace paddle {
namespace operators {
#define MAX_INPUT_NUM 3 // the max num of ET for BroadcastConfig
namespace kps = paddle::operators::kernel_primitives;
struct DimensionsTransform {
......@@ -45,9 +46,10 @@ struct DimensionsTransform {
axis++;
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"The %d-th dimension of input tensor is expected to be equal "
"with the %d-th dimension of output tensor %d or 1, but "
"recieved %d.",
"The %dth dimension of input tensor is expected to be equal "
"with"
"the %dth dimension of output tensor %d or 1, but recieved "
"%d.\n",
in_idx + 1, axis + 1, out_dims[axis], in_dim[in_idx]));
}
} while (in_idx < in_dim.size());
......@@ -59,9 +61,10 @@ struct DimensionsTransform {
in_idx++;
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"The %d-th dimension of input tensor is expected to be equal "
"with the %d-th dimension of output tensor %d or 1, but "
"recieved %d.",
"The %dth dimension of input tensor is expected to be equal "
"with"
"the %dth dimension of output tensor %d or 1, but recieved "
"%d.\n",
in_idx + 1, in_idx + 1, out_dims[in_idx], in_dim[in_idx]));
}
} while (in_idx < dim_size);
......@@ -162,71 +165,79 @@ struct DimensionsTransform {
}
};
template <typename T, int VecSize, int Rank, bool IsBoundary = false>
template <typename T, int VecSize, int ShapeSize, bool IsBoundary = false>
__device__ __forceinline__ void LoadData(
T *dst, const T *__restrict__ src, uint32_t block_offset,
const kps::details::BroadcastConfig<Rank> &config, int numel, int num,
const kps::details::BroadcastConfig<ShapeSize> &config, int numel, int num,
bool need_broadcast) {
// numel: the total number of output elements
// num: the number of elements handled this time
if (need_broadcast) {
kps::ReadDataBc<T, VecSize, 1, 1, Rank, IsBoundary>(dst, src, block_offset,
config, numel, 1, 1);
kps::ReadDataBc<T, VecSize, 1, 1, ShapeSize, IsBoundary>(
dst, src, block_offset, config, numel, 1, 1);
} else {
kps::ReadData<T, VecSize, 1, 1, IsBoundary>(dst, src + block_offset, num);
}
}
template <typename InT, typename OutT, typename Functor, int Arity, int VecSize,
int Rank, bool IsBoundary = false>
template <ElementwiseType ET, typename InT, typename OutT, int ShapeSize,
int VecSize, typename Functor, bool IsBoundary = false>
__device__ void DealSegment(
const framework::Array<const InT *__restrict__, Arity> &ins, OutT *out,
const framework::Array<bool, Arity> &use_broadcast, uint32_t numel,
const framework::Array<kps::details::BroadcastConfig<Rank>, Arity> &configs,
const framework::Array<const InT *__restrict__, ET> &in, OutT *out,
const framework::Array<bool, MAX_INPUT_NUM> &use_broadcast, uint32_t numel,
const framework::Array<kps::details::BroadcastConfig<ShapeSize>,
MAX_INPUT_NUM> &configlists,
int num, Functor func) {
InT args[Arity][VecSize];
InT args[ET][VecSize];
OutT result[VecSize];
int block_offset = blockIdx.x * blockDim.x * VecSize;
// load
#pragma unroll
for (int i = 0; i < Arity; i++) {
for (int i = 0; i < ET; i++) {
kps::Init<InT, VecSize>(args[i], static_cast<InT>(1.0f));
LoadData<InT, VecSize, Rank, IsBoundary>(args[i], ins[i], block_offset,
configs[i], numel, num,
use_broadcast[i]);
LoadData<InT, VecSize, ShapeSize, IsBoundary>(args[i], in[i], block_offset,
configlists[i], numel, num,
use_broadcast[i]);
}
const bool kCallElementwiseAny =
platform::FunctionTraits<Functor>::has_pointer_args;
ElementwisePrimitiveCaller<InT, OutT, VecSize, Functor, Arity,
kCallElementwiseAny>()(func, args, result);
// compute
if (ET == kUnary) {
kps::ElementwiseUnary<InT, OutT, VecSize, 1, 1, Functor>(result, args[0],
func);
} else if (ET == kBinary) {
kps::ElementwiseBinary<InT, OutT, VecSize, 1, 1, Functor>(result, args[0],
args[1], func);
} else {
kps::ElementwiseTernary<InT, OutT, VecSize, 1, 1, Functor>(
result, args[0], args[1], args[2], func);
}
// compute
kps::WriteData<OutT, VecSize, 1, 1, IsBoundary>(out + block_offset, result,
num);
}
template <typename InT, typename OutT, typename Functor, int Arity, int VecSize,
int Rank>
template <ElementwiseType ET, typename InT, typename OutT, int ShapeSize,
int VecSize, typename Functor>
__global__ void BroadcastKernel(
framework::Array<const InT *__restrict__, Arity> ins, OutT *out,
framework::Array<bool, Arity> use_broadcast, uint32_t numel,
framework::Array<kps::details::BroadcastConfig<Rank>, Arity> configs,
framework::Array<const InT *__restrict__, ET> in, OutT *out,
framework::Array<bool, MAX_INPUT_NUM> use_broadcast, uint32_t numel,
framework::Array<kps::details::BroadcastConfig<ShapeSize>, MAX_INPUT_NUM>
configlists,
int main_tid, int tail_tid, Functor func) {
int block_offset = blockIdx.x * blockDim.x * VecSize;
// data offset of this block
if (blockIdx.x < main_tid) {
int num = blockDim.x * VecSize; // blockIdx.x < main_tid
DealSegment<InT, OutT, Functor, Arity, VecSize, Rank, false>(
ins, out, use_broadcast, numel, configs, num, func);
DealSegment<ET, InT, OutT, ShapeSize, VecSize, Functor, false>(
in, out, use_broadcast, numel, configlists, num, func);
} else { // remainder
int num = tail_tid;
DealSegment<InT, OutT, Functor, Arity, VecSize, Rank, true>(
ins, out, use_broadcast, numel, configs, num, func);
DealSegment<ET, InT, OutT, ShapeSize, VecSize, Functor, true>(
in, out, use_broadcast, numel, configlists, num, func);
}
}
template <typename InT, typename OutT, typename Functor, int Arity, int VecSize,
int Rank>
template <typename InT, typename OutT, ElementwiseType ET, int VecSize,
int Size, typename Functor>
void LaunchKernel(const platform::CUDADeviceContext &ctx,
const std::vector<const framework::Tensor *> &ins,
framework::Tensor *out, Functor func,
......@@ -240,58 +251,53 @@ void LaunchKernel(const platform::CUDADeviceContext &ctx,
auto stream = ctx.stream();
OutT *out_data = out->data<OutT>();
framework::Array<kps::details::BroadcastConfig<Rank>, Arity> configs;
framework::Array<bool, Arity> use_broadcast;
framework::Array<const InT *__restrict__, Arity> ins_data;
framework::Array<kps::details::BroadcastConfig<Size>, MAX_INPUT_NUM>
configlists;
framework::Array<bool, MAX_INPUT_NUM> use_broadcast;
framework::Array<const InT *__restrict__, ET> ins_data;
for (int i = 0; i < Arity; i++) {
for (int i = 0; i < ET; i++) {
use_broadcast[i] = (ins[i]->numel() != numel);
ins_data[i] = ins[i]->data<InT>();
if (use_broadcast[i]) {
// get the broadcast config;
// if the data shape is [m, n], then set data_dim = {n, m},
// e.g. if out's shape is [3, 45, 1], then out_dims = {1, 45, 3}
configs[i] = kps::details::BroadcastConfig<Rank>(
configlists[i] = kps::details::BroadcastConfig<Size>(
merge_dims.out_dims, merge_dims.in_dims[i], merge_dims.dim_size);
}
}
BroadcastKernel<InT, OutT, Functor, Arity, VecSize,
Rank><<<blocks, threads, 0, stream>>>(
ins_data, out_data, use_broadcast, numel, configs, main_tid, tail_tid,
BroadcastKernel<ET, InT, OutT, Size, VecSize,
Functor><<<blocks, threads, 0, stream>>>(
ins_data, out_data, use_broadcast, numel, configlists, main_tid, tail_tid,
func);
}
template <typename InT, typename OutT, typename Functor, int Arity, int VecSize>
void LaunchBroadcastKernelForDifferentVecSize(
template <typename InT, typename OutT, ElementwiseType ET, int VecSize,
typename Functor>
void LaunchBroadcastKernelForDifferentDimSize(
const platform::CUDADeviceContext &ctx,
const std::vector<const framework::Tensor *> &ins, framework::Tensor *out,
int axis, Functor func) {
const auto merge_dims = DimensionsTransform(ins, out->dims(), axis);
#define CALL_BROADCAST_FOR_DIM_SIZE(rank) \
case rank: { \
LaunchKernel<InT, OutT, Functor, Arity, VecSize, rank>(ctx, ins, out, \
func, merge_dims); \
#define DIM_SIZE(size) \
case size: { \
LaunchKernel<InT, OutT, ET, VecSize, size, Functor>(ctx, ins, out, func, \
merge_dims); \
} break;
switch (merge_dims.dim_size) {
CALL_BROADCAST_FOR_DIM_SIZE(1);
CALL_BROADCAST_FOR_DIM_SIZE(2);
CALL_BROADCAST_FOR_DIM_SIZE(3);
CALL_BROADCAST_FOR_DIM_SIZE(4);
CALL_BROADCAST_FOR_DIM_SIZE(5);
CALL_BROADCAST_FOR_DIM_SIZE(6);
CALL_BROADCAST_FOR_DIM_SIZE(7);
CALL_BROADCAST_FOR_DIM_SIZE(8);
default: {
PADDLE_THROW(platform::errors::InvalidArgument(
"The maximum dimension of input tensor is expected to be less than "
"%d, but recieved %d.\n",
merge_dims.dim_size, framework::DDim::kMaxRank));
}
DIM_SIZE(1);
DIM_SIZE(2);
DIM_SIZE(3);
DIM_SIZE(4);
DIM_SIZE(5);
DIM_SIZE(6);
DIM_SIZE(7);
DIM_SIZE(8);
}
#undef CALL_BROADCAST_FOR_DIM_SIZE
#undef DIM_SIZE
}
template <ElementwiseType ET, typename InT, typename OutT, typename Functor>
......@@ -299,21 +305,11 @@ void LaunchBroadcastElementwiseCudaKernel(
const platform::CUDADeviceContext &ctx,
const std::vector<const framework::Tensor *> &ins,
std::vector<framework::Tensor *> *outs, int axis, Functor func) {
using Traits = platform::FunctionTraits<Functor>;
const int kArity =
Traits::has_pointer_args ? static_cast<int>(ET) : Traits::arity;
PADDLE_ENFORCE_EQ(ins.size(), kArity,
PADDLE_ENFORCE_EQ(ET, ElementwiseType::kBinary,
platform::errors::InvalidArgument(
"The number of inputs is expected to be equal to the "
"arity of functor. But recieved: the number of inputs "
"is %d, the arity of functor is %d.",
ins.size(), kArity));
PADDLE_ENFORCE_EQ(kArity, 2,
platform::errors::InvalidArgument(
"Currently only broadcast of binary is supported and "
"verified, but received %d.",
kArity));
"Currently, only Support binary calculation, "
"but received %d input tensors.\n",
static_cast<int>(ET)));
int in_vec_size = 4;
framework::Tensor *out = (*outs)[0];
for (auto *in : ins) {
......@@ -326,18 +322,18 @@ void LaunchBroadcastElementwiseCudaKernel(
switch (vec_size) {
case 4: {
LaunchBroadcastKernelForDifferentVecSize<InT, OutT, Functor, kArity, 4>(
ctx, ins, out, axis, func);
LaunchBroadcastKernelForDifferentDimSize<InT, OutT, ET, 4>(ctx, ins, out,
axis, func);
break;
}
case 2: {
LaunchBroadcastKernelForDifferentVecSize<InT, OutT, Functor, kArity, 2>(
ctx, ins, out, axis, func);
LaunchBroadcastKernelForDifferentDimSize<InT, OutT, ET, 2>(ctx, ins, out,
axis, func);
break;
}
case 1: {
LaunchBroadcastKernelForDifferentVecSize<InT, OutT, Functor, kArity, 1>(
ctx, ins, out, axis, func);
LaunchBroadcastKernelForDifferentDimSize<InT, OutT, ET, 1>(ctx, ins, out,
axis, func);
break;
}
default: {
......@@ -373,5 +369,7 @@ void LaunchElementwiseCudaKernel(
}
}
#undef MAX_INPUT_NUM
} // namespace operators
} // namespace paddle
......@@ -37,10 +37,8 @@ limitations under the License. */
#endif
#include <thrust/iterator/iterator_adaptor.h>
#include "paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h"
#include "paddle/fluid/platform/cuda_device_function.h"
#include "paddle/fluid/platform/cuda_primitives.h"
#ifdef __HIPCC__
constexpr int ELEMWISE_MAX_BLOCK_DIM = 256;
#else
......@@ -280,6 +278,128 @@ void CommonForwardBroadcastCPU(const framework::Tensor *x,
}
}
#if defined(__NVCC__) || defined(__HIPCC__)
template <typename Functor, typename T, typename OutType>
__global__ void ElementwiseKernel(const T *__restrict__ x_data,
const T *__restrict__ y_data,
OutType *__restrict__ out_data, int n,
int post, const size_t total, Functor func) {
int tid = threadIdx.x + blockDim.x * blockIdx.x;
int stride = blockDim.x * gridDim.x;
for (int i = tid; i < total; i += stride) {
int idx = i / post % n;
out_data[i] = func(x_data[i], y_data[idx]);
}
}
template <typename Functor, typename T, typename OutType>
void ComputeElementwiseCUDA(const framework::Tensor *x,
const framework::Tensor *y, framework::Tensor *z,
int pre, int n, int post,
const platform::CUDADeviceContext &ctx,
Functor func, const bool is_xsize_larger = true) {
const T *x_data = x->data<T>();
const T *y_data = y->data<T>();
OutType *out_data = z->mutable_data<OutType>(ctx.GetPlace());
int numel = pre * n * post;
int threads = 256;
int blocks = (numel + threads - 1) / threads;
if (is_xsize_larger) {
ElementwiseKernel<Functor, T,
OutType><<<blocks, threads, 0, ctx.stream()>>>(
x_data, y_data, out_data, n, post, numel, func);
} else {
ElementwiseKernel<Functor, T,
OutType><<<blocks, threads, 0, ctx.stream()>>>(
y_data, x_data, out_data, n, post, numel, func);
}
}
template <typename Functor, typename T, typename OutType = T>
__global__ void CommonForwardBroadcastCUDAKernel(
const int *x_strides_array, const int *y_strides_array,
const int *out_dims_array, const T *x, const T *y, OutType *out,
int out_size, int max_dim, Functor func, const bool is_xsize_larger) {
for (int out_index = blockIdx.x * blockDim.x + threadIdx.x;
out_index < out_size; out_index += blockDim.x * gridDim.x) {
int x_index = 0;
int y_index = 0;
int out_index_quotient = out_index;
int remainder = 0;
#pragma unroll
for (int i = max_dim - 1; i >= 0; --i) {
GetDivMod(out_index_quotient, out_dims_array[i], &out_index_quotient,
&remainder);
x_index += remainder * x_strides_array[i];
y_index += remainder * y_strides_array[i];
}
if (is_xsize_larger) {
out[out_index] = func(x[x_index], y[y_index]);
} else {
out[out_index] = func(y[y_index], x[x_index]);
}
}
}
template <typename Functor, typename T, typename OutType = T>
void CommonForwardBroadcastCUDA(
const framework::Tensor *x, const framework::Tensor *y,
framework::Tensor *z, int *x_dims_array, int *y_dims_array,
int *out_dims_array, int max_dim, const platform::CUDADeviceContext &ctx,
Functor func, const bool is_xsize_larger = true) {
const auto gplace = BOOST_GET_CONST(platform::CUDAPlace, ctx.GetPlace());
auto cplace = platform::CPUPlace();
const T *x_data = x->data<T>();
const T *y_data = y->data<T>();
OutType *out_data = z->mutable_data<OutType>(ctx.GetPlace());
std::vector<int> x_strides_array(max_dim);
std::vector<int> y_strides_array(max_dim);
int x_stride = 1;
int y_stride = 1;
for (int i = max_dim - 1; i >= 0; i--) {
x_strides_array[i] = x_dims_array[i] == 1 ? 0 : x_stride;
y_strides_array[i] = y_dims_array[i] == 1 ? 0 : y_stride;
x_stride *= x_dims_array[i];
y_stride *= y_dims_array[i];
}
int bytes = max_dim * sizeof(int);
auto x_strides_array_tmp = memory::Alloc(ctx, bytes);
int *x_strides_array_gpu =
reinterpret_cast<int *>(x_strides_array_tmp->ptr());
memory::Copy(gplace, x_strides_array_gpu, cplace, x_strides_array.data(),
bytes, ctx.stream());
auto y_strides_array_tmp = memory::Alloc(ctx, bytes);
int *y_strides_array_gpu =
reinterpret_cast<int *>(y_strides_array_tmp->ptr());
memory::Copy(gplace, y_strides_array_gpu, cplace, y_strides_array.data(),
bytes, ctx.stream());
auto out_dims_array_tmp = memory::Alloc(ctx, bytes);
int *out_dims_array_gpu = reinterpret_cast<int *>(out_dims_array_tmp->ptr());
memory::Copy(gplace, out_dims_array_gpu, cplace, out_dims_array, bytes,
ctx.stream());
const int out_size = std::accumulate(out_dims_array, out_dims_array + max_dim,
1, std::multiplies<int>());
dim3 gird_size = dim3(
(out_size + PADDLE_CUDA_THREAD_SIZE - 1) / PADDLE_CUDA_THREAD_SIZE, 1);
dim3 block_size = dim3(PADDLE_CUDA_THREAD_SIZE, 1);
CommonForwardBroadcastCUDAKernel<
Functor, T, OutType><<<gird_size, block_size, 0, ctx.stream()>>>(
x_strides_array_gpu, y_strides_array_gpu, out_dims_array_gpu, x_data,
y_data, out_data, out_size, max_dim, func, is_xsize_larger);
}
#endif // __NVCC__ or __HIPCC__
template <typename T, typename DX_OP, typename DY_OP>
void CommonGradBroadcastCPU(
const framework::Tensor &x, const framework::Tensor &y,
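A host-side sketch of the zero-stride indexing used by CommonForwardBroadcastCUDAKernel in the hunk above: a dimension of size 1 gets stride 0, so decomposing an output index with div/mod always lands on a valid input element. Shapes below are illustrative only:

#include <cstdio>
#include <vector>

int main() {
  // out dims {2, 3}; x dims {2, 3}; y dims {1, 3} (y broadcasts over dim 0).
  std::vector<int> out_dims = {2, 3}, x_dims = {2, 3}, y_dims = {1, 3};
  const int max_dim = 2;
  std::vector<int> x_strides(max_dim), y_strides(max_dim);
  int x_stride = 1, y_stride = 1;
  for (int i = max_dim - 1; i >= 0; --i) {
    x_strides[i] = x_dims[i] == 1 ? 0 : x_stride;  // broadcast dim -> stride 0
    y_strides[i] = y_dims[i] == 1 ? 0 : y_stride;
    x_stride *= x_dims[i];
    y_stride *= y_dims[i];
  }
  for (int out_index = 0; out_index < 6; ++out_index) {
    int x_index = 0, y_index = 0, quotient = out_index, remainder = 0;
    for (int i = max_dim - 1; i >= 0; --i) {
      remainder = quotient % out_dims[i];
      quotient /= out_dims[i];
      x_index += remainder * x_strides[i];
      y_index += remainder * y_strides[i];
    }
    std::printf("out %d -> x %d, y %d\n", out_index, x_index, y_index);
  }
  return 0;
}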
......@@ -1797,10 +1917,21 @@ void CommonElementwiseBroadcastForward(
y_dims_array.data(), out_dims_array.data(), max_dim,
axis);
CommonForwardBroadcastCPU<Functor, T, OutType>(
x, y, z, x_dims_array.data(), y_dims_array.data(), out_dims_array.data(),
max_dim, ctx.template device_context<platform::CPUDeviceContext>(), func,
is_xsize_larger);
if (platform::is_gpu_place(ctx.GetPlace())) {
#if defined(__NVCC__) || defined(__HIPCC__)
CommonForwardBroadcastCUDA<Functor, T, OutType>(
x, y, z, x_dims_array.data(), y_dims_array.data(),
out_dims_array.data(), max_dim,
ctx.template device_context<platform::CUDADeviceContext>(), func,
is_xsize_larger);
#endif
} else {
CommonForwardBroadcastCPU<Functor, T, OutType>(
x, y, z, x_dims_array.data(), y_dims_array.data(),
out_dims_array.data(), max_dim,
ctx.template device_context<platform::CPUDeviceContext>(), func,
is_xsize_larger);
}
}
template <typename DeviceContext, typename T, typename DX_OP, typename DY_OP>
......@@ -1844,35 +1975,12 @@ void ElemwiseExplicitGradCompute(const framework::ExecutionContext &ctx,
}
}
// It is a common implementation to compute binary calculation with the support
// of broadcast, supporting both CPU and GPU.
// - CPU implementation cannot support the case when x needs broadcast, thus
// this function needs to be called with XxxFunctor and XxxInverseFunctor,
// like paddle/fluid/operators/elementwise/elementwise_add_op.h#L49 - L55.
// - GPU implementation supports all the broadcast cases, thus there is no need
// to define and call with XxxInverseFunctor.
// TODO(liuyiqun): optimize the CPU implementation to support all broadcast
// cases and avoid the need of XxxInverseFunctor.
template <typename Functor, typename DeviceContext, typename T,
typename OutType = T>
void ElementwiseComputeEx(const framework::ExecutionContext &ctx,
const framework::Tensor *x,
const framework::Tensor *y, int axis, Functor func,
framework::Tensor *z) {
if (platform::is_gpu_place(ctx.GetPlace())) {
#if defined(__NVCC__) || defined(__HIPCC__)
std::vector<const framework::Tensor *> ins = {x, y};
std::vector<framework::Tensor *> outs = {z};
z->mutable_data<OutType>(ctx.GetPlace());
const auto &dev_ctx =
ctx.template device_context<platform::CUDADeviceContext>();
LaunchElementwiseCudaKernel<ElementwiseType::kBinary, T, OutType>(
dev_ctx, ins, &outs, axis, func);
#endif
return;
}
auto x_dims = x->dims();
auto y_dims = y->dims();
bool is_xsize_larger = true;
......@@ -1921,6 +2029,15 @@ void ElementwiseComputeEx(const framework::ExecutionContext &ctx,
return;
}
if (platform::is_gpu_place(ctx.GetPlace())) {
#if defined(__NVCC__) || defined(__HIPCC__)
ComputeElementwiseCUDA<Functor, T, OutType>(
x, y, z, pre, n, post,
ctx.template device_context<platform::CUDADeviceContext>(), func,
is_xsize_larger);
#endif
return;
}
if (post == 1) {
functor.RunRowWise(n, pre);
return;
......
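The comment removed in this hunk explains that the CPU path of ElementwiseComputeEx cannot broadcast x, so callers pair a functor with an inverse functor and swap operands when y is the larger tensor. A standalone sketch of that idea (functor and helper names are illustrative, not Paddle APIs):

#include <cstdio>

template <typename T>
struct SubFunctor {
  T operator()(T a, T b) const { return a - b; }
};

template <typename T>
struct InverseSubFunctor {  // used when the operands arrive swapped
  T operator()(T a, T b) const { return b - a; }
};

template <typename Functor, typename T>
T ComputeWithBroadcastSide(T x, T y, bool is_xsize_larger, Functor func) {
  // Mirrors the dispatch idea: the larger tensor is always the first operand.
  return is_xsize_larger ? func(x, y) : func(y, x);
}

int main() {
  // x larger: x - y directly. y larger: still computes x - y via the inverse functor.
  std::printf("%d %d\n",
              ComputeWithBroadcastSide(5, 2, true, SubFunctor<int>()),
              ComputeWithBroadcastSide(5, 2, false, InverseSubFunctor<int>()));
  return 0;
}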
......@@ -11,13 +11,12 @@ distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/operators/kernel_primitives/kernel_primitives.h"
#include "paddle/fluid/platform/aligned_vector.h"
#include "paddle/fluid/platform/function_traits.h"
#include "paddle/fluid/platform/cuda_device_function.h"
#include "paddle/fluid/platform/fast_divmod.h"
#ifdef __HIPCC__
#define ELEMENTWISE_BLOCK_SIZE 256
......@@ -29,8 +28,7 @@ namespace paddle {
namespace operators {
namespace kps = paddle::operators::kernel_primitives;
enum ElementwiseType { kUnary = 1, kBinary = 2, kTernary = 3, kAny = -1 };
enum ElementwiseType { kUnary = 1, kBinary = 2, kTernary = 3 };
/*
* According to NVIDIA, if number of threads per block is 64/128/256/512,
......@@ -57,9 +55,8 @@ inline int GetThreadsConfig(const platform::CUDADeviceContext &ctx,
}
template <typename InT, typename OutT>
int GetVectorizedSizeForTensors(
const std::vector<const framework::Tensor *> &ins,
const std::vector<framework::Tensor *> &outs) {
int GetVectorizedSizeForIO(const std::vector<const framework::Tensor *> &ins,
const std::vector<framework::Tensor *> &outs) {
int vec_size = 4;
for (auto iter = ins.begin(); iter != ins.end(); ++iter) {
vec_size = std::min<int>(vec_size,
......@@ -72,88 +69,56 @@ int GetVectorizedSizeForTensors(
return vec_size;
}
template <typename InT, typename OutT, int VecSize, typename Functor, int Arity,
bool CallElementwiseAny = false>
struct ElementwisePrimitiveCaller {
__device__ inline OutT operator()(Functor func, InT (*args)[VecSize],
OutT *result);
};
template <typename InT, typename OutT, int VecSize, typename Functor, int Arity>
struct ElementwisePrimitiveCaller<InT, OutT, VecSize, Functor, Arity, true> {
__device__ inline OutT operator()(Functor func, InT (*args)[VecSize],
OutT *result) {
kps::ElementwiseAny<InT, OutT, VecSize, 1, 1, Arity, Functor>(result, args,
func);
template <ElementwiseType ET, int VecSize, typename InT, typename OutT,
typename Functor, bool IsBoundary>
__device__ void DealSegment(
const framework::Array<const InT *__restrict__, ET> &in, OutT *out, int num,
Functor func) {
int data_offset = VecSize * blockIdx.x * blockDim.x;
InT args[ET][VecSize];
OutT result[VecSize];
// load data
#pragma unroll
for (int i = 0; i < ET; i++) {
kps::Init<InT, VecSize>(args[i], static_cast<InT>(1.0f));
kps::ReadData<InT, VecSize, 1, 1, IsBoundary>(args[i], in[i] + data_offset,
num);
}
};
template <typename InT, typename OutT, int VecSize, typename Functor>
struct ElementwisePrimitiveCaller<InT, OutT, VecSize, Functor, 1, false> {
__device__ inline OutT operator()(Functor func, InT (*args)[VecSize],
OutT *result) {
// compute
if (ET == kUnary) {
kps::ElementwiseUnary<InT, OutT, VecSize, 1, 1, Functor>(result, args[0],
func);
}
};
template <typename InT, typename OutT, int VecSize, typename Functor>
struct ElementwisePrimitiveCaller<InT, OutT, VecSize, Functor, 2, false> {
__device__ inline OutT operator()(Functor func, InT (*args)[VecSize],
OutT *result) {
} else if (ET == kBinary) {
kps::ElementwiseBinary<InT, OutT, VecSize, 1, 1, Functor>(result, args[0],
args[1], func);
}
};
template <typename InT, typename OutT, int VecSize, typename Functor>
struct ElementwisePrimitiveCaller<InT, OutT, VecSize, Functor, 3, false> {
__device__ inline OutT operator()(Functor func, InT **args, OutT *result) {
} else {
kps::ElementwiseTernary<InT, OutT, VecSize, 1, 1, Functor>(
result, args[0], args[1], args[2], func);
}
};
template <typename InT, typename OutT, typename Functor, int Arity, int VecSize,
bool IsBoundary>
__device__ void DealSegment(
const framework::Array<const InT *__restrict__, Arity> &in, OutT *out,
int num, Functor func) {
InT args[Arity][VecSize];
OutT result[VecSize];
int data_offset = VecSize * blockIdx.x * blockDim.x;
#pragma unroll
for (int i = 0; i < Arity; i++) {
kps::Init<InT, VecSize>(args[i], static_cast<InT>(1.0f));
kps::ReadData<InT, VecSize, 1, 1, IsBoundary>(args[i], in[i] + data_offset,
num);
}
const bool kCallElementwiseAny =
platform::FunctionTraits<Functor>::has_pointer_args;
ElementwisePrimitiveCaller<InT, OutT, VecSize, Functor, Arity,
kCallElementwiseAny>()(func, args, result);
// store
kps::WriteData<OutT, VecSize, 1, 1, IsBoundary>(out + data_offset, result,
num);
}
template <typename InT, typename OutT, typename Functor, int Arity, int VecSize>
template <ElementwiseType ET, int VecSize, typename InT, typename OutT,
typename Functor>
__global__ void ElementVectorizeKernel(
framework::Array<const InT *__restrict__, Arity> ins, OutT *out, int size,
framework::Array<const InT *__restrict__, ET> in, OutT *out, int size,
Functor func) {
int data_offset = VecSize * blockIdx.x * blockDim.x;
int num = size - data_offset;
// the number of elements to handle this time
if (VecSize * blockDim.x > num) { // remainder segment
DealSegment<InT, OutT, Functor, Arity, VecSize, true>(ins, out, num, func);
DealSegment<ET, VecSize, InT, OutT, Functor, true>(in, out, num, func);
} else { // complete segment
DealSegment<InT, OutT, Functor, Arity, VecSize, false>(ins, out, num, func);
DealSegment<ET, VecSize, InT, OutT, Functor, false>(in, out, num, func);
}
}
template <typename InT, typename OutT, typename Functor, int Arity, int VecSize>
template <ElementwiseType ET, typename InT, typename OutT, typename Functor,
int VecSize>
void ElementwiseCudaKernel(const platform::CUDADeviceContext &ctx,
const std::vector<const framework::Tensor *> &ins,
std::vector<framework::Tensor *> *outs,
......@@ -164,14 +129,14 @@ void ElementwiseCudaKernel(const platform::CUDADeviceContext &ctx,
((numel + VecSize - 1) / VecSize + block_size - 1) / block_size;
auto stream = ctx.stream();
OutT *out_data = (*outs)[0]->data<OutT>();
framework::Array<const InT *__restrict__, Arity> ins_data;
for (int i = 0; i < Arity; i++) {
ins_data[i] = ins[i]->data<InT>();
OutT *out = (*outs)[0]->data<OutT>();
framework::Array<const InT *__restrict__, ET> in;
for (int i = 0; i < ET; i++) {
in[i] = ins[i]->data<InT>();
}
ElementVectorizeKernel<InT, OutT, Functor, Arity,
VecSize><<<grid_size, block_size, 0, stream>>>(
ins_data, out_data, numel, func);
ElementVectorizeKernel<ET, VecSize, InT, OutT,
Functor><<<grid_size, block_size, 0, stream>>>(
in, out, numel, func);
}
template <ElementwiseType ET, typename InT, typename OutT, typename Functor>
......@@ -179,30 +144,17 @@ void LaunchSameDimsElementwiseCudaKernel(
const platform::CUDADeviceContext &ctx,
const std::vector<const framework::Tensor *> &ins,
std::vector<framework::Tensor *> *outs, Functor func) {
using Traits = platform::FunctionTraits<Functor>;
const int kArity =
Traits::has_pointer_args ? static_cast<int>(ET) : Traits::arity;
PADDLE_ENFORCE_EQ(ins.size(), kArity,
platform::errors::InvalidArgument(
"The number of inputs is expected to be equal to the "
"arity of functor. But recieved: the number of inputs "
"is %d, the arity of functor is %d.",
ins.size(), kArity));
// calculate the max vec_size for all ins and outs
int vec_size = GetVectorizedSizeForTensors<InT, OutT>(ins, *outs);
int vec_size = GetVectorizedSizeForIO<InT, OutT>(ins, *outs);
switch (vec_size) {
case 4:
ElementwiseCudaKernel<InT, OutT, Functor, kArity, 4>(ctx, ins, outs,
func);
ElementwiseCudaKernel<ET, InT, OutT, Functor, 4>(ctx, ins, outs, func);
break;
case 2:
ElementwiseCudaKernel<InT, OutT, Functor, kArity, 2>(ctx, ins, outs,
func);
ElementwiseCudaKernel<ET, InT, OutT, Functor, 2>(ctx, ins, outs, func);
break;
case 1:
ElementwiseCudaKernel<InT, OutT, Functor, kArity, 1>(ctx, ins, outs,
func);
ElementwiseCudaKernel<ET, InT, OutT, Functor, 1>(ctx, ins, outs, func);
break;
default: {
PADDLE_THROW(platform::errors::Unimplemented(
......
......@@ -22,6 +22,13 @@ namespace plat = paddle::platform;
namespace paddle {
namespace operators {
template <typename T>
struct CudaSubFunctor {
inline HOSTDEVICE T operator()(const T* args) const {
return args[0] - args[1];
}
};
template <typename T>
class ElementwiseSubKernel<platform::CUDADeviceContext, T>
: public framework::OpKernel<T> {
......@@ -34,7 +41,7 @@ class ElementwiseSubKernel<platform::CUDADeviceContext, T>
int axis = PackTensorsIntoVector<T>(ctx, &ins, &outs);
LaunchElementwiseCudaKernel<ElementwiseType::kBinary, T, T>(
cuda_ctx, ins, &outs, axis, SubFunctor<T>());
cuda_ctx, ins, &outs, axis, CudaSubFunctor<T>());
}
};
......
......@@ -52,8 +52,10 @@ template <typename T>
using ReduceParamType = typename CudnnDataType<T>::BatchNormParamType;
template <typename T>
struct AddFunctor {
inline HOSTDEVICE T operator()(const T& a, const T& b) const { return a + b; }
struct CudaAddFunctor {
inline HOSTDEVICE T operator()(const T* args) const {
return args[0] + args[1];
}
};
template <typename InT, typename OutT, int ShapeSize, int VecSize,
......@@ -126,7 +128,7 @@ void LaunchBiasAddFwKernel(const platform::CUDADeviceContext& ctx, int m, int n,
std::vector<int64_t> out_dims = {n, m};
configlists[1] = kps::details::BroadcastConfig<2>(out_dims, input1_dims, 2);
auto func = AddFunctor<T>();
auto func = CudaAddFunctor<T>();
auto stream = ctx.stream();
switch (vec_size) {
case 4: {
......
......@@ -21,6 +21,7 @@
#include <hip/hip_fp16.h>
#endif
// #include <algorithm>
#include "paddle/fluid/platform/cuda_device_function.h"
#include "paddle/fluid/platform/float16.h"
......@@ -134,114 +135,53 @@ __device__ __forceinline__ T BlockYReduce(T val, ReduceOp reducer) {
} // namespace details
/**
* @brief unary function
* @param
* T: data type of in
* OutT: data type of out
* NX: the cols of in
* NY: the rows of in
* BlockSize: the config of this device
* OpFunc: compute functor which have an operator() as following
* template <typename T, typename OutT>
* struct XxxFunctor {
* HOSTDEVICE OutT operator()(const T& a) const {
* return ...;
* }
* };
*/
template <typename T, typename OutT, int NX, int NY, int BlockSize,
class OpFunc>
__device__ __forceinline__ void ElementwiseUnary(OutT* out, const T* in,
OpFunc compute) {
#pragma unroll
for (int idx = 0; idx < NX * NY; idx++) {
out[idx] = static_cast<OutT>(compute(in[idx]));
}
}
/*************************** Compute Function****************************/
/**
* @brief binary function, in1 and in2 have same shape
* @param
* @param
* T: data type of in1, in2
* OutT: data type of out
* NX: the cols of in1, in2
* NY: the rows of in1, in2
* BlockSize: the config of this device
* OpFunc: compute functor which have an operator() as following
* template <typename T, typename OutT>
* struct XxxFunctor {
* HOSTDEVICE OutT operator()(const T& a, const T& b) const {
* return ...;
* }
* };
* OpFunc: compute functor eg: in1 + in2, in1 - in2
*/
template <typename T, typename OutT, int NX, int NY, int BlockSize,
class OpFunc>
__device__ __forceinline__ void ElementwiseBinary(OutT* out, const T* in1,
const T* in2,
OpFunc compute) {
T args[2];
#pragma unroll
for (int idx = 0; idx < NX * NY; ++idx) {
out[idx] = static_cast<OutT>(compute(in1[idx], in2[idx]));
args[0] = in1[idx];
args[1] = in2[idx];
out[idx] = static_cast<OutT>(compute(args));
}
}
/**
* @brief ternary function, in1, in2 and in3 have same shape
* @param
* @param
* T: data type of in1, in2, in3
* OutT: data type of out
* NX: the cols of in1, in2
* NY: the rows of in1, in2
* BlockSize: the config of this device
* OpFunc: compute functor which have an operator() as following
* template <typename T, typename OutT>
* struct XxxFunctor {
* HOSTDEVICE OutT operator()(const T& a, const T& b, const T& c) const {
* return ...;
* }
* };
* OpFunc: compute functor eg: out = in1 * in2 + in3
*/
template <typename T, typename OutT, int NX, int NY, int BlockSize,
class OpFunc>
__device__ __forceinline__ void ElementwiseTernary(OutT* out, const T* in1,
const T* in2, const T* in3,
OpFunc compute) {
T args[3];
#pragma unroll
for (int idx = 0; idx < NX * NY; ++idx) {
out[idx] = static_cast<OutT>(compute(in1[idx], in2[idx], in3[idx]));
}
}
/**
* @brief a general function for elementwise computation, all inputs have
* the same shape.
* @param
* T: data type of in1, in2, in3
* OutT: data type of out
* NX: the cols of in1, in2
* NY: the rows of in1, in2
* BlockSize: the config of this device
* OpFunc: compute functor which have an operator() as following
* template <typename T, typename OutT>
* struct XxxFunctor {
* HOSTDEVICE OutT operator()(const T* args) const {
* return ...;
* }
* };
*/
template <typename T, typename OutT, int NX, int NY, int BlockSize, int Arity,
class OpFunc>
__device__ __forceinline__ void ElementwiseAny(OutT* out, T (*ins)[NX * NY],
OpFunc compute) {
T args[Arity];
#pragma unroll
for (int idx = 0; idx < NX * NY; ++idx) {
#pragma unroll
for (int j = 0; j < Arity; ++j) {
args[j] = ins[j][idx];
}
args[0] = in1[idx];
args[1] = in2[idx];
args[2] = in3[idx];
out[idx] = static_cast<OutT>(compute(args));
}
}
......@@ -249,7 +189,7 @@ __device__ __forceinline__ void ElementwiseAny(OutT* out, T (*ins)[NX * NY],
/**
* @brief cycle binary function, in1's shape size is [1, NX], in2's shape size
* is [NY, NX], out's shape size is [NY, NX]
* @param
* @param
* T: data type of in1, in2
* OutT: data type of out
* NX: the cols of in1, in2
......@@ -271,6 +211,26 @@ __device__ __forceinline__ void CycleBinary(OutT* out, const T* in1,
}
}
/**
* @brief unary function
* @param:
* T: data type of in
* OutT: data type of out
* NX: the cols of in
* NY: the rows of in
* BlockSize: the config of this device
* OpFunc: compute functor eg: relu, exp
*/
template <typename T, typename OutT, int NX, int NY, int BlockSize,
class OpFunc>
__device__ __forceinline__ void ElementwiseUnary(OutT* out, const T* in,
OpFunc compute) {
#pragma unroll
for (int idx = 0; idx < NX * NY; idx++) {
out[idx] = static_cast<OutT>(compute(in + idx));
}
}
/**
* @brief reduce function, in's shape size is [NX, NY].
* If ReduceMode == kLocalMode then reduce NX, the shape of out is [NY, 1],
......@@ -278,7 +238,7 @@ __device__ __forceinline__ void CycleBinary(OutT* out, const T* in1,
* shape of out is [NY, NX]. If reduce_last_dim is false and reduce_num was
* split, BlockYReduce will be called. If reduce_last_dim is true and
* reduce_num was split, BlockXReduce will be called
* @typename
* @typename
* T: data type of in
* NX: the cols of in
* NY: the rows of in
......
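The kernel-primitives hunk above shows the post-revert compute helpers packing element values into a local args array before invoking a pointer-argument functor. A host-side sketch of that pattern for the binary case (names and the loop bound are illustrative):

#include <cstdio>

// Host analogue of the ElementwiseBinary loop: pack values into args[2],
// then hand the packed array to a pointer-argument functor.
template <typename T, typename OutT, int NX, class OpFunc>
void ElementwiseBinaryHost(OutT* out, const T* in1, const T* in2,
                           OpFunc compute) {
  T args[2];
  for (int idx = 0; idx < NX; ++idx) {
    args[0] = in1[idx];
    args[1] = in2[idx];
    out[idx] = static_cast<OutT>(compute(args));
  }
}

struct AddArgs {
  float operator()(const float* args) const { return args[0] + args[1]; }
};

int main() {
  float a[4] = {1, 2, 3, 4}, b[4] = {10, 20, 30, 40}, c[4];
  ElementwiseBinaryHost<float, float, 4>(c, a, b, AddArgs());
  std::printf("%g %g %g %g\n", c[0], c[1], c[2], c[3]);
  return 0;
}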
......@@ -15,14 +15,18 @@
#include <unsupported/Eigen/SpecialFunctions>
#include "paddle/fluid/operators/elementwise/elementwise_op_impl.cu.h"
#include "paddle/fluid/operators/lgamma_op.h"
#include "paddle/fluid/operators/math/complex_functors.h"
namespace paddle {
namespace operators {
template <typename T, typename Enable = void>
struct CudaLgammaFunctor;
template <typename T>
struct CudaLgammaFunctor {
__device__ __forceinline__ T operator()(const T& x) const {
return Eigen::numext::lgamma(x);
struct CudaLgammaFunctor<T, math::NoComplex<T, math::Real<T>>> {
__device__ __forceinline__ T operator()(const T* args) const {
return Eigen::numext::lgamma(args[0]);
}
};
......@@ -33,14 +37,15 @@ class LgammaKernel<platform::CUDADeviceContext, T>
void Compute(const framework::ExecutionContext& context) const override {
const Tensor* x = context.Input<Tensor>("X");
Tensor* out = context.Output<Tensor>("Out");
out->mutable_data<T>(context.GetPlace());
out->mutable_data<math::Real<T>>(context.GetPlace());
auto& dev_ctx = context.device_context<platform::CUDADeviceContext>();
std::vector<const framework::Tensor*> ins = {x};
std::vector<framework::Tensor*> outs = {out};
auto functor = CudaLgammaFunctor<T>();
LaunchSameDimsElementwiseCudaKernel<ElementwiseType::kUnary, T, T>(
dev_ctx, ins, &outs, functor);
LaunchSameDimsElementwiseCudaKernel<ElementwiseType::kUnary, T,
math::Real<T>>(dev_ctx, ins, &outs,
functor);
}
};
......
......@@ -129,10 +129,17 @@ class MatrixRankGPUKernel : public framework::OpKernel<T> {
compare_result.mutable_data<int64_t>(detail::NewAxisDim(dim_out, k),
context.GetPlace());
int axis = -1;
ElementwiseComputeEx<GreaterThanFunctor<T>, platform::CUDADeviceContext, T,
int64_t>(context, &eigenvalue_tensor, &tol_tensor,
axis, GreaterThanFunctor<T>(),
&compare_result);
if (eigenvalue_tensor.dims().size() >= tol_tensor.dims().size()) {
ElementwiseComputeEx<GreaterThanFunctor<T>, platform::CUDADeviceContext,
T, int64_t>(context, &eigenvalue_tensor, &tol_tensor,
axis, GreaterThanFunctor<T>(),
&compare_result);
} else {
ElementwiseComputeEx<LessThanFunctor<T>, platform::CUDADeviceContext, T,
int64_t>(context, &eigenvalue_tensor, &tol_tensor,
axis, LessThanFunctor<T>(),
&compare_result);
}
auto dito_int =
math::DeviceIndependenceTensorOperations<platform::CUDADeviceContext,
int64_t>(context);
......
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <tuple>
namespace paddle {
namespace platform {
// Declare a template class with a single template parameter.
template <typename>
struct FunctionTraits;
// A forwarding trait allowing functors (objects which have an operator())
// to be used with this traits class.
template <typename T>
struct FunctionTraits : public FunctionTraits<decltype(&T::operator())> {};
// A partial specialization of FunctionTraits for pointers to member functions.
template <typename ClassType, typename ReturnType, typename... Args>
struct FunctionTraits<ReturnType (ClassType::*)(Args...) const> {
static const size_t arity = sizeof...(Args);
static const bool has_pointer_args =
(arity == 1) &&
(std::is_pointer<
typename std::tuple_element<0, std::tuple<Args...>>::type>::value);
};
} // namespace platform
} // namespace paddle
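For reference, the traits above (removed by this revert) report has_pointer_args == true only for functors whose operator() takes a single pointer argument. A standalone sketch reproducing that logic with the two functor styles (illustrative names):

#include <cstddef>
#include <cstdio>
#include <tuple>
#include <type_traits>

template <typename>
struct FunctionTraits;

// Forward functors (objects with an operator()) to the member-function traits.
template <typename T>
struct FunctionTraits : public FunctionTraits<decltype(&T::operator())> {};

template <typename ClassType, typename ReturnType, typename... Args>
struct FunctionTraits<ReturnType (ClassType::*)(Args...) const> {
  static const size_t arity = sizeof...(Args);
  static const bool has_pointer_args =
      (arity == 1) &&
      (std::is_pointer<
          typename std::tuple_element<0, std::tuple<Args...>>::type>::value);
};

struct PointerStyleAdd {  // post-revert style: packed argument array
  float operator()(const float* args) const { return args[0] + args[1]; }
};

struct ValueStyleAdd {  // pre-revert (#35487) style: one value per argument
  float operator()(const float& a, const float& b) const { return a + b; }
};

int main() {
  std::printf("pointer-style: arity=%zu has_pointer_args=%d\n",
              static_cast<size_t>(FunctionTraits<PointerStyleAdd>::arity),
              static_cast<int>(FunctionTraits<PointerStyleAdd>::has_pointer_args));
  std::printf("value-style:   arity=%zu has_pointer_args=%d\n",
              static_cast<size_t>(FunctionTraits<ValueStyleAdd>::arity),
              static_cast<int>(FunctionTraits<ValueStyleAdd>::has_pointer_args));
  return 0;
}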