Unverified · Commit 14949521 · Author: limingshu · Committed by: GitHub

Binary functor invoking of elementwise broadcast (#32928)

Parent: 6f8de31d
@@ -52,8 +52,9 @@ class AbsKernel<platform::CUDADeviceContext, T>
     std::vector<const framework::Tensor*> ins = {x};
     std::vector<framework::Tensor*> outs = {out};
     auto functor = CudaAbsFunctor<T>();
-    LaunchElementwiseCudaKernel<ElementwiseType::kUnary, T, math::Real<T>>(
-        dev_ctx, ins, &outs, functor);
+    LaunchSameDimsElementwiseCudaKernel<ElementwiseType::kUnary, T,
+                                        math::Real<T>>(dev_ctx, ins, &outs,
+                                                       functor);
   }
 };
...
@@ -1316,8 +1316,8 @@ class ActivationCudaKernel
     for (auto& attr : attrs) {
       *attr.second = ctx.Attr<float>(attr.first);
     }
-    LaunchElementwiseCudaKernel<ElementwiseType::kUnary, T, T>(dev_ctx, ins,
-                                                               &outs, functor);
+    LaunchSameDimsElementwiseCudaKernel<ElementwiseType::kUnary, T, T>(
+        dev_ctx, ins, &outs, functor);
   }
 };
@@ -1346,16 +1346,16 @@ class ActivationGradCudaKernel
     if (static_cast<int>(Functor::FwdDeps()) == static_cast<int>(kDepOut)) {
       // Only need forward output Out
       ins.push_back(out);
-      LaunchElementwiseCudaKernel<ElementwiseType::kBinary, T, T>(
+      LaunchSameDimsElementwiseCudaKernel<ElementwiseType::kBinary, T, T>(
           dev_ctx, ins, &outs, functor);
     } else if (static_cast<int>(Functor::FwdDeps()) ==
               static_cast<int>(kDepX)) {
       // Only need forward input X
       ins.push_back(x);
-      LaunchElementwiseCudaKernel<ElementwiseType::kBinary, T, T>(
+      LaunchSameDimsElementwiseCudaKernel<ElementwiseType::kBinary, T, T>(
           dev_ctx, ins, &outs, functor);
     } else {
-      LaunchElementwiseCudaKernel<ElementwiseType::kUnary, T, T>(
+      LaunchSameDimsElementwiseCudaKernel<ElementwiseType::kUnary, T, T>(
           dev_ctx, ins, &outs, functor);
     }
   }
...
@@ -69,15 +69,6 @@ struct SameDimsElemwiseAdd<
   }
 };

-template <typename T>
-struct BroadcastElemwiseAdd<platform::CPUDeviceContext, T> {
-  void operator()(const framework::ExecutionContext &ctx,
-                  const framework::Tensor *x, const framework::Tensor *y,
-                  framework::Tensor *z) {
-    default_elementwise_add<platform::CPUDeviceContext, T>(ctx, x, y, z);
-  }
-};
-
 class ElementwiseAddOpMaker : public ElementwiseOpMaker {
  protected:
   std::string GetName() const override { return "Add"; }
...
@@ -13,7 +13,6 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 #include "paddle/fluid/operators/elementwise/elementwise_add_op.h"
 #include "paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h"
-#include "paddle/fluid/operators/elementwise/elementwise_op_impl.cu.h"
 #include "paddle/fluid/platform/complex128.h"
 #include "paddle/fluid/platform/complex64.h"
 #include "paddle/fluid/platform/float16.h"
@@ -40,29 +39,24 @@ struct CudaAddFunctor {
 };

 template <typename T>
-struct SameDimsElemwiseAdd<platform::CUDADeviceContext, T> {
-  void operator()(const framework::ExecutionContext& ctx,
-                  const framework::Tensor* x, const framework::Tensor* y,
-                  framework::Tensor* z) {
+class ElementwiseAddKernel<platform::CUDADeviceContext, T>
+    : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const override {
+    auto* x = ctx.Input<framework::LoDTensor>("X");
+    auto* y = ctx.Input<framework::LoDTensor>("Y");
+    auto* z = ctx.Output<framework::LoDTensor>("Out");
+    z->mutable_data<T>(ctx.GetPlace());
+    int axis = ctx.Attr<int>("axis");
+    axis = axis == -1 ? std::abs(x->dims().size() - y->dims().size()) : axis;
     std::vector<const framework::Tensor*> ins = {x, y};
     std::vector<framework::Tensor*> outs = {z};
-    LaunchElementwiseCudaKernel<ElementwiseType::kBinary, T, T>(
-        ctx.template device_context<platform::CUDADeviceContext>(), ins, &outs,
-        CudaAddFunctor<T>());
-  }
-};
-
-template <typename T>
-struct BroadcastElemwiseAdd<platform::CUDADeviceContext, T> {
-  void operator()(const framework::ExecutionContext& ctx,
-                  const framework::Tensor* x, const framework::Tensor* y,
-                  framework::Tensor* out) {
-    std::vector<const framework::Tensor*> ins = {x, y};
-    int axis = ctx.Attr<int>("axis");
-    axis = axis == -1 ? std::abs(x->dims().size() - y->dims().size()) : axis;
-    LaunchBroadcastElementwiseCudaKernel<ElementwiseType::kBinary, T>(
-        ctx.template device_context<platform::CUDADeviceContext>(), ins, out,
-        CudaAddFunctor<T>(), axis);
+    const auto& cuda_ctx =
+        ctx.template device_context<platform::CUDADeviceContext>();
+
+    LaunchElementwiseCudaKernel<ElementwiseType::kBinary, T, T>(
+        cuda_ctx, ins, &outs, axis, CudaAddFunctor<T>());
   }
 };
...
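The body of CudaAddFunctor sits outside the captured hunks. Judging from how the functor is applied in the broadcast kernels further down (args[0] = broadcast_warpper.func(args), with T args[ET]), a binary functor for this path presumably takes a pointer to ET packed operands and returns a single value. A minimal sketch under that assumption, not the committed definition:

template <typename T>
struct CudaAddFunctorSketch {
  // `args` holds ET (= 2 for ElementwiseType::kBinary) operands gathered by
  // BroadcastArgsWarpper; the functor reduces them to one output element.
  inline HOSTDEVICE T operator()(const T* args) const {
    return args[0] + args[1];
  }
};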
@@ -26,7 +26,7 @@ limitations under the License. */
 #include <cuda.h>
 #include <cuda_fp16.h>
 #include "cub/cub.cuh"
+#include "paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h"
 #endif
 #ifdef __HIPCC__
 #include <hip/hip_fp16.h>
@@ -40,9 +40,10 @@ namespace paddle {
 namespace operators {

 template <typename DeviceContext, typename T>
-void default_elementwise_add(const framework::ExecutionContext &ctx,
-                             const framework::Tensor *x,
-                             const framework::Tensor *y, framework::Tensor *z) {
+void LaunchBroadcastElementwiseCpuKernel(const framework::ExecutionContext &ctx,
+                                         const framework::Tensor *x,
+                                         const framework::Tensor *y,
+                                         framework::Tensor *z) {
   int axis = ctx.Attr<int>("axis");
   auto x_dims = x->dims();
   auto y_dims = y->dims();
@@ -62,13 +63,6 @@ struct SameDimsElemwiseAdd {
                   framework::Tensor *z);
 };

-template <typename DeviceContext, typename T, class Enable = void>
-struct BroadcastElemwiseAdd {
-  void operator()(const framework::ExecutionContext &ctx,
-                  const framework::Tensor *x, const framework::Tensor *y,
-                  framework::Tensor *z);
-};
-
 template <typename DeviceContext, typename T>
 class ElementwiseAddKernel : public framework::OpKernel<T> {
  public:
@@ -77,13 +71,13 @@ class ElementwiseAddKernel : public framework::OpKernel<T> {
     auto *y = ctx.Input<framework::LoDTensor>("Y");
     auto *z = ctx.Output<framework::LoDTensor>("Out");
     z->mutable_data<T>(ctx.GetPlace());
-    auto dims_equal = x->dims() == y->dims();
-    if (dims_equal) {
-      SameDimsElemwiseAdd<DeviceContext, T> same_dims_add;
-      same_dims_add(ctx, x, y, z);
+    if (x->dims() == y->dims()) {
+      SameDimsElemwiseAdd<platform::CPUDeviceContext, T>
+          LaunchElementwiseCpuKernel;
+      LaunchElementwiseCpuKernel(ctx, x, y, z);
     } else {
-      BroadcastElemwiseAdd<DeviceContext, T> broadcast_add;
-      broadcast_add(ctx, x, y, z);
+      LaunchBroadcastElementwiseCpuKernel<platform::CPUDeviceContext, T>(ctx, x,
+                                                                         y, z);
     }
   }
 };
@@ -469,8 +463,8 @@ class ElementwiseAddDoubleGradKernel : public framework::OpKernel<T> {
       GetDoubleGradSafeTensor<DeviceContext, T>(ctx, y, ddy, &ddy_safe);
       ddout->mutable_data<T>(ctx.GetPlace());
-      default_elementwise_add<DeviceContext, T>(ctx, &ddx_safe, &ddy_safe,
-                                                ddout);
+      LaunchBroadcastElementwiseCpuKernel<DeviceContext, T>(ctx, &ddx_safe,
+                                                            &ddy_safe, ddout);
     }
   }
 };
...
@@ -14,7 +14,7 @@
 #pragma once
-#include "paddle/fluid/operators/elementwise/elementwise_op_broadcast_impl.cu.h"
+#include "paddle/fluid/operators/elementwise/elementwise_op_impl.cu.h"
 namespace paddle {
 namespace operators {
@@ -28,7 +28,8 @@ struct DimensionsTransform {
   std::vector<DimVector> in_dims;

  private:
-  // 1. To compensate the lackage of input_tensors` dimension;
+  // To compensate the lackage of input_tensors` dimension with input variable
+  // 'axis'
   void InputDimensionsExtend(int N, int axis) {
     for (auto &in_dim : in_dims) {
       int64_t in_idx = 0;
@@ -70,7 +71,7 @@ struct DimensionsTransform {
   }

   template <typename MergeFunctor>
-  __inline__ void DimensionsReorganise(MergeFunctor merge_func, int N) {
+  __inline__ void MergeDimensions(MergeFunctor merge_func, int N) {
     auto VectorReorganise = [](DimVector *vec, int l_idx, int m_idx) {
       (*vec)[m_idx - 1] =
           std::accumulate(vec->begin() + l_idx, vec->begin() + m_idx, 1,
@@ -139,7 +140,7 @@
     // To Merge the dimensions of input_tensors while the consequtive
     // equal-dimensions appears.
     MergeFunctor merge_ptr = merge_sequential_dims;
-    DimensionsReorganise<MergeFunctor>(merge_ptr, N);
+    MergeDimensions<MergeFunctor>(merge_ptr, N);
     int min_idx = 0;
     int min_val = std::accumulate(in_dims[0].begin(), in_dims[0].end(), 1,
@@ -155,12 +156,12 @@
     // To Merge the dimension of input_tensors while the consequtive
     // 1-value-dimensions appears.
     merge_ptr = merge_sequential_one_dims;
-    DimensionsReorganise<MergeFunctor>(merge_ptr, N);
+    MergeDimensions<MergeFunctor>(merge_ptr, N);
     std::swap(in_dims[min_idx], in_dims[0]);
   }
 };

-struct CalculateInputStrides {
+struct StridesCalculation {
   std::vector<std::vector<uint32_t>> strides;
   std::vector<FastDivMod> divmoders;
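To make the two merge passes above concrete (an illustration, not text from the commit): broadcasting an input of shape [32, 1, 1, 16] against one of shape [32, 8, 4, 16] would presumably collapse the consecutive 1-valued run in the first input into a single broadcast dimension, leaving effective shapes [32, 1, 16] and [32, 32, 16]; when the two shapes are fully identical, the consecutive equal dimensions merge the whole problem into one contiguous run. The launcher further below then only has to dispatch on the merged rank (the switch over merge_dims.dim_size, cases 1 through 8).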
@@ -181,9 +182,9 @@ struct CalculateInputStrides {
     }
   }

  public:
-  explicit CalculateInputStrides(
-      const int64_t &dim_size, const std::vector<std::vector<int64_t>> &in_dims,
-      const std::vector<int64_t> &out_dims) {
+  explicit StridesCalculation(const int64_t &dim_size,
+                              const std::vector<std::vector<int64_t>> &in_dims,
+                              const std::vector<int64_t> &out_dims) {
     const auto N = in_dims.size();
     divmoders.resize(dim_size);
     strides.resize(N, std::vector<uint32_t>(dim_size, 1));
@@ -195,34 +196,40 @@
   }
 };

-template <typename T, ElementwiseType ET, int VecSize, int kDims>
+template <typename T, typename Functor, ElementwiseType ET, int VecSize,
+          int kDims>
 struct BroadcastArgsWarpper {
-  using DimsVec = CudaAlignedVector<T, VecSize>;
+  using VecType = CudaAlignedVector<T, VecSize>;

   T *out_data;
+  VecType *vec_out_data;
   const T *__restrict__ in_data[ET];
-  uint32_t strides[ET][framework::DDim::kMaxRank];
+  const VecType *__restrict__ vec_in_data[ET];
   bool no_broadcast[ET];
   FastDivMod divmoders[kDims];
-  uint32_t scalar_offset;
+  uint32_t strides[ET][framework::DDim::kMaxRank];
+  uint32_t scalar_cal_offset;
+  Functor func;

   HOSTDEVICE BroadcastArgsWarpper(
-      const std::vector<const framework::Tensor *> &ins,
-      const CalculateInputStrides &offset_calculator, framework::Tensor *out,
-      int scalar_offset)
-      : scalar_offset(scalar_offset) {
+      const std::vector<const framework::Tensor *> &ins, framework::Tensor *out,
+      int scalar_cal_offset, Functor func,
+      const StridesCalculation &offset_calculator)
+      : scalar_cal_offset(scalar_cal_offset), func(func) {
     for (int j = 0; j < ET; ++j) {
       in_data[j] = ins[j]->data<T>();
+      vec_in_data[j] = reinterpret_cast<const VecType *>(in_data[j]);
       no_broadcast[j] = ins[j]->dims() == out->dims() ? true : false;
       memcpy(strides[j], offset_calculator.strides[j].data(),
              kDims * sizeof(uint32_t));
     }
     out_data = out->data<T>();
+    vec_out_data = reinterpret_cast<VecType *>(out_data);
     memcpy(divmoders, offset_calculator.divmoders.data(),
            kDims * sizeof(FastDivMod));
   }

-  __device__ __forceinline__ uint32_t GetDivmodOffset(int idx, int in_idx) {
+  __device__ __forceinline__ uint32_t GetOffsetByDivmod(int idx, int in_idx) {
     uint32_t offset = 0;

 #pragma unroll(kDims)
@@ -234,120 +241,127 @@ struct BroadcastArgsWarpper {
     return offset;
   }

-  __device__ __forceinline__ void CommonVector(DimsVec args[], int tid,
-                                               int idx) {
-    const DimsVec *__restrict__ vec_data =
-        reinterpret_cast<const DimsVec *__restrict__>(in_data[idx]);
-    args[idx] = vec_data[tid];
+  __device__ __forceinline__ void LoadVectorizedDataCommon(VecType *vector_args,
+                                                           int tid, int idx) {
+    *vector_args = vec_in_data[idx][tid];
   }

-  __device__ __forceinline__ void DivmodVector(DimsVec args[], int tid,
-                                               int idx) {
+  __device__ __forceinline__ void LoadVectorizedDataByDivmod(T *scalar_args,
+                                                             int tid, int idx) {
     int index = tid * VecSize;
+#pragma unroll(VecSize)
     for (int i = 0; i < VecSize; ++i) {
-      uint32_t offset = GetDivmodOffset(index + i, idx);
-      args[idx].val[i] = in_data[idx][offset];
+      uint32_t offset = GetOffsetByDivmod(index + i, idx);
+      scalar_args[i] = in_data[idx][offset];
     }
   }

-  __device__ __forceinline__ void CommonScalar(T args[], int tid, int idx) {
-    args[idx] = in_data[idx][tid + scalar_offset];
+  __device__ __forceinline__ void LoadScalarizedDataCommon(T args[], int tid,
+                                                           int idx) {
+    args[idx] = in_data[idx][tid + scalar_cal_offset];
   }

-  __device__ __forceinline__ void DivmodScalar(T args[], int tid, int idx) {
-    auto offset = GetDivmodOffset(tid + scalar_offset, idx);
+  __device__ __forceinline__ void LoadScalarizedDataByDivmod(T args[], int tid,
+                                                             int idx) {
+    auto offset = GetOffsetByDivmod(tid + scalar_cal_offset, idx);
     args[idx] = in_data[idx][offset];
   }

-  __device__ __forceinline__ void LoadVector(DimsVec args[], int tid) {
+  __device__ __forceinline__ void LoadVectorizedData(T (*args)[VecSize],
+                                                     int tid) {
 #pragma unroll(ET)
     for (int j = 0; j < ET; ++j) {
       if (no_broadcast[j]) {
-        CommonVector(args, tid, j);
+        VecType *vector_args = reinterpret_cast<VecType *>(args[j]);
+        LoadVectorizedDataCommon(vector_args, tid, j);
       } else {
-        DivmodVector(args, tid, j);
+        LoadVectorizedDataByDivmod(args[j], tid, j);
       }
     }
   }

-  __device__ __forceinline__ void LoadScalar(T args[], int tid) {
+  __device__ __forceinline__ void LoadScalarizedData(T args[], int tid) {
 #pragma unroll(ET)
     for (int j = 0; j < ET; ++j) {
       if (no_broadcast[j]) {
-        CommonScalar(args, tid, j);
+        LoadScalarizedDataCommon(args, tid, j);
       } else {
-        DivmodScalar(args, tid, j);
+        LoadScalarizedDataByDivmod(args, tid, j);
       }
     }
   }

-  __device__ __forceinline__ void StoreVector(DimsVec args[], int tid) {
-    DimsVec *vec_out = reinterpret_cast<DimsVec *>(out_data);
-    vec_out[tid] = args[0];
+  __device__ __forceinline__ void StoreVectorizedData(T (*args)[VecSize],
+                                                      int tid) {
+    VecType *args_out = reinterpret_cast<VecType *>(args[0]);
+    vec_out_data[tid] = *args_out;
   }

-  __device__ __forceinline__ void StoreScalar(T args[], int tid) {
-    out_data[scalar_offset + tid] = args[0];
+  __device__ __forceinline__ void StoreScalarizedData(T args[], int tid) {
+    out_data[scalar_cal_offset + tid] = args[0];
   }
 };

 template <typename T, typename BroadcastArgsWarpper, ElementwiseType ET>
 __device__ inline void ScalarizedBroadcastKernelImpl(
-    BroadcastArgsWarpper data_transfer, int tid) {
+    BroadcastArgsWarpper broadcast_warpper, int tid) {
   T args[ET];
-  data_transfer.LoadScalar(args, tid);
+  broadcast_warpper.LoadScalarizedData(args, tid);

 #pragma unroll(ET)
   for (int j = 1; j < ET; ++j) {
-    args[0] += args[j];
+    args[0] = broadcast_warpper.func(args);
   }
-  data_transfer.StoreScalar(args, tid);
+  broadcast_warpper.StoreScalarizedData(args, tid);
 }

 template <typename T, typename BroadcastArgsWarpper, ElementwiseType ET,
           int VecSize>
 __device__ inline void VectorizedBroadcastKernelImpl(
-    BroadcastArgsWarpper data_transfer, int tid) {
-  using VecT = CudaAlignedVector<T, VecSize>;
-  VecT args[ET];
-  data_transfer.LoadVector(args, tid);
+    BroadcastArgsWarpper broadcast_warpper, int tid) {
+  T ins[ET];
+  T args[ET][VecSize];
+  broadcast_warpper.LoadVectorizedData(args, tid);

-#pragma unroll(ET)
-  for (int j = 1; j < ET; ++j) {
 #pragma unroll(VecSize)
-    for (int i = 0; i < VecSize; ++i) {
-      args[0].val[i] += args[j].val[i];
+  for (int i = 0; i < VecSize; ++i) {
+#pragma unroll(ET)
+    for (int j = 0; j < ET; ++j) {
+      ins[j] = args[j][i];
     }
+    args[0][i] = broadcast_warpper.func(ins);
   }
-  data_transfer.StoreVector(args, tid);
+  broadcast_warpper.StoreVectorizedData(args, tid);
 }

 template <typename T, typename BroadcastArgsWarpper, ElementwiseType ET,
           int VecSize>
-__global__ void ElementwiseBroadcastKernel(BroadcastArgsWarpper data_transfer,
-                                           int main_tid, int tail_tid) {
+__global__ void ElementwiseBroadcastKernel(
+    BroadcastArgsWarpper broadcast_warpper, int main_tid, int tail_tid) {
   int tid = threadIdx.x + blockIdx.x * blockDim.x;
-  // Aimming at vectorized calculation of major data whose length is max
-  // multipler of VecSize.
+  // Vectorized calculation of major data whose length is the max multipler of
+  // VecSize,
+  // eg: Calcualting the front 1024-length data in total 1027 data once VecSize
+  // is 4.
   if (tid < main_tid) {
     VectorizedBroadcastKernelImpl<T, BroadcastArgsWarpper, ET, VecSize>(
-        data_transfer, tid);
+        broadcast_warpper, tid);
   }
-  // Aimming at scalar calculation of rest data whose lenght cannot fulfill
-  // VecSize.
+  // Scalarzed calculation of rest data whose lenght cannot fulfill VecSize.
+  // eg: Calcualting the rest 3-length data in total 1027 data once VecSize is
+  // 4.
   if (tid < tail_tid) {
-    ScalarizedBroadcastKernelImpl<T, BroadcastArgsWarpper, ET>(data_transfer,
-                                                               tid);
+    ScalarizedBroadcastKernelImpl<T, BroadcastArgsWarpper, ET>(
+        broadcast_warpper, tid);
   }
 }
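The comments added above use a 1027-element example; the split itself is computed in LaunchBroadcastKernelForDifferentDimSize below, and the exact main_tid/tail_tid lines fall outside the captured hunks. A small standalone sketch of how that split presumably works, under that assumption:

// Standalone sketch (not taken from the diff) of the vectorized/scalar split
// described in the kernel comments: 1027 elements with VecSize = 4.
#include <cstdio>

int main() {
  int numel = 1027;
  int vec_size = 4;
  int main_tid = numel / vec_size;              // 256 vectorized work-items -> 1024 elements
  int tail_tid = numel % vec_size;              // 3 leftover elements, handled scalar-wise
  int scalar_cal_offset = main_tid * vec_size;  // scalar tail starts at element 1024
  printf("main_tid=%d tail_tid=%d scalar_cal_offset=%d\n", main_tid, tail_tid,
         scalar_cal_offset);
  return 0;
}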

-template <typename T, ElementwiseType ET, int VecSize = 1>
+template <typename T, ElementwiseType ET, int VecSize, typename Functor>
 void LaunchBroadcastKernelForDifferentDimSize(
     const platform::CUDADeviceContext &ctx,
     const std::vector<const framework::Tensor *> &ins, framework::Tensor *out,
-    int axis) {
+    int axis, Functor func) {
   int numel = out->numel();
   const int threads = 256;
   int blocks = ((numel + VecSize - 1) / VecSize + threads - 1) / threads;
@@ -357,72 +371,72 @@ void LaunchBroadcastKernelForDifferentDimSize(
   auto stream = ctx.stream();

   const auto merge_dims = DimensionsTransform(ins, out->dims(), axis);
-  const auto offset_calculator = CalculateInputStrides(
+  const auto offset_calculator = StridesCalculation(
       merge_dims.dim_size, merge_dims.in_dims, merge_dims.out_dims);

   switch (merge_dims.dim_size) {
     case 1: {
-      auto data_transfer = BroadcastArgsWarpper<T, ET, VecSize, 1>(
-          ins, offset_calculator, out, vec_len);
-      ElementwiseBroadcastKernel<T, decltype(data_transfer), ET,
+      auto broadcast_warpper = BroadcastArgsWarpper<T, Functor, ET, VecSize, 1>(
+          ins, out, vec_len, func, offset_calculator);
+      ElementwiseBroadcastKernel<T, decltype(broadcast_warpper), ET,
                                  VecSize><<<blocks, threads, 0, stream>>>(
-          data_transfer, main_tid, tail_tid);
+          broadcast_warpper, main_tid, tail_tid);
       break;
     }
     case 2: {
-      auto data_transfer = BroadcastArgsWarpper<T, ET, VecSize, 2>(
-          ins, offset_calculator, out, vec_len);
-      ElementwiseBroadcastKernel<T, decltype(data_transfer), ET,
+      auto broadcast_warpper = BroadcastArgsWarpper<T, Functor, ET, VecSize, 2>(
+          ins, out, vec_len, func, offset_calculator);
+      ElementwiseBroadcastKernel<T, decltype(broadcast_warpper), ET,
                                  VecSize><<<blocks, threads, 0, stream>>>(
-          data_transfer, main_tid, tail_tid);
+          broadcast_warpper, main_tid, tail_tid);
       break;
     }
     case 3: {
-      auto data_transfer = BroadcastArgsWarpper<T, ET, VecSize, 3>(
-          ins, offset_calculator, out, vec_len);
-      ElementwiseBroadcastKernel<T, decltype(data_transfer), ET,
+      auto broadcast_warpper = BroadcastArgsWarpper<T, Functor, ET, VecSize, 3>(
+          ins, out, vec_len, func, offset_calculator);
+      ElementwiseBroadcastKernel<T, decltype(broadcast_warpper), ET,
                                  VecSize><<<blocks, threads, 0, stream>>>(
-          data_transfer, main_tid, tail_tid);
+          broadcast_warpper, main_tid, tail_tid);
       break;
     }
     case 4: {
-      auto data_transfer = BroadcastArgsWarpper<T, ET, VecSize, 4>(
-          ins, offset_calculator, out, vec_len);
-      ElementwiseBroadcastKernel<T, decltype(data_transfer), ET,
+      auto broadcast_warpper = BroadcastArgsWarpper<T, Functor, ET, VecSize, 4>(
+          ins, out, vec_len, func, offset_calculator);
+      ElementwiseBroadcastKernel<T, decltype(broadcast_warpper), ET,
                                  VecSize><<<blocks, threads, 0, stream>>>(
-          data_transfer, main_tid, tail_tid);
+          broadcast_warpper, main_tid, tail_tid);
       break;
     }
     case 5: {
-      auto data_transfer = BroadcastArgsWarpper<T, ET, VecSize, 5>(
-          ins, offset_calculator, out, vec_len);
-      ElementwiseBroadcastKernel<T, decltype(data_transfer), ET,
+      auto broadcast_warpper = BroadcastArgsWarpper<T, Functor, ET, VecSize, 5>(
+          ins, out, vec_len, func, offset_calculator);
+      ElementwiseBroadcastKernel<T, decltype(broadcast_warpper), ET,
                                  VecSize><<<blocks, threads, 0, stream>>>(
-          data_transfer, main_tid, tail_tid);
+          broadcast_warpper, main_tid, tail_tid);
       break;
     }
     case 6: {
-      auto data_transfer = BroadcastArgsWarpper<T, ET, VecSize, 6>(
-          ins, offset_calculator, out, vec_len);
-      ElementwiseBroadcastKernel<T, decltype(data_transfer), ET,
+      auto broadcast_warpper = BroadcastArgsWarpper<T, Functor, ET, VecSize, 6>(
+          ins, out, vec_len, func, offset_calculator);
+      ElementwiseBroadcastKernel<T, decltype(broadcast_warpper), ET,
                                  VecSize><<<blocks, threads, 0, stream>>>(
-          data_transfer, main_tid, tail_tid);
+          broadcast_warpper, main_tid, tail_tid);
       break;
     }
     case 7: {
-      auto data_transfer = BroadcastArgsWarpper<T, ET, VecSize, 7>(
-          ins, offset_calculator, out, vec_len);
-      ElementwiseBroadcastKernel<T, decltype(data_transfer), ET,
+      auto broadcast_warpper = BroadcastArgsWarpper<T, Functor, ET, VecSize, 7>(
+          ins, out, vec_len, func, offset_calculator);
+      ElementwiseBroadcastKernel<T, decltype(broadcast_warpper), ET,
                                  VecSize><<<blocks, threads, 0, stream>>>(
-          data_transfer, main_tid, tail_tid);
+          broadcast_warpper, main_tid, tail_tid);
       break;
     }
     case 8: {
-      auto data_transfer = BroadcastArgsWarpper<T, ET, VecSize, 8>(
-          ins, offset_calculator, out, vec_len);
-      ElementwiseBroadcastKernel<T, decltype(data_transfer), ET,
+      auto broadcast_warpper = BroadcastArgsWarpper<T, Functor, ET, VecSize, 8>(
+          ins, out, vec_len, func, offset_calculator);
+      ElementwiseBroadcastKernel<T, decltype(broadcast_warpper), ET,
                                  VecSize><<<blocks, threads, 0, stream>>>(
-          data_transfer, main_tid, tail_tid);
+          broadcast_warpper, main_tid, tail_tid);
       break;
     }
     default: {
@@ -437,9 +451,11 @@ void LaunchBroadcastKernelForDifferentDimSize(
 template <ElementwiseType ET, typename T, typename Functor>
 void LaunchBroadcastElementwiseCudaKernel(
     const platform::CUDADeviceContext &ctx,
-    const std::vector<const framework::Tensor *> &ins, framework::Tensor *out,
-    Functor func, int axis) {
+    const std::vector<const framework::Tensor *> &ins,
+    std::vector<framework::Tensor *> *outs, int axis, Functor func) {
+  static_assert(ET == (ElementwiseType)2, "Only Support binary calculation.");
   int in_vec_size = 4;
+  framework::Tensor *out = (*outs)[0];
   for (auto *in : ins) {
     auto temp_size = GetVectorizedSizeImpl<T>(in->data<T>());
     in_vec_size = in->dims() == out->dims() ? std::min(temp_size, in_vec_size)
@@ -450,19 +466,46 @@ void LaunchBroadcastElementwiseCudaKernel(
   switch (vec_size) {
     case 4: {
-      LaunchBroadcastKernelForDifferentDimSize<T, ET, 4>(ctx, ins, out, axis);
+      LaunchBroadcastKernelForDifferentDimSize<T, ET, 4>(ctx, ins, out, axis,
+                                                         func);
       break;
     }
     case 2: {
-      LaunchBroadcastKernelForDifferentDimSize<T, ET, 2>(ctx, ins, out, axis);
+      LaunchBroadcastKernelForDifferentDimSize<T, ET, 2>(ctx, ins, out, axis,
+                                                         func);
+      break;
+    }
+    case 1: {
+      LaunchBroadcastKernelForDifferentDimSize<T, ET, 1>(ctx, ins, out, axis,
+                                                         func);
       break;
     }
     default: {
-      LaunchBroadcastKernelForDifferentDimSize<T, ET, 1>(ctx, ins, out, axis);
+      PADDLE_THROW(platform::errors::Unimplemented(
+          "Unsupported vectorized size: %d !", vec_size));
       break;
     }
   }
 }
+
+template <ElementwiseType ET, typename InT, typename OutType, typename Functor>
+void LaunchElementwiseCudaKernel(
+    const platform::CUDADeviceContext &cuda_ctx,
+    const std::vector<const framework::Tensor *> &ins,
+    std::vector<framework::Tensor *> *outs, int axis, Functor func) {
+  bool no_broadcast_flag = true;
+  for (auto *in : ins) {
+    no_broadcast_flag = ins[0]->dims() == in->dims();
+  }
+
+  if (no_broadcast_flag) {
+    LaunchSameDimsElementwiseCudaKernel<ElementwiseType::kBinary, InT, OutType>(
+        cuda_ctx, ins, outs, func);
+  } else {
+    LaunchBroadcastElementwiseCudaKernel<ElementwiseType::kBinary, InT>(
+        cuda_ctx, ins, outs, axis, func);
+  }
+}

 }  // namespace operators
 }  // namespace paddle
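The new LaunchElementwiseCudaKernel entry point is what the commit title refers to: the binary functor is now threaded through both the same-dims path and the broadcast path, so other binary ops can reuse the dispatch unchanged. A hedged sketch of such a caller (CudaMulFunctorSketch and LaunchMulSketch are illustrative names, not part of this commit; the snippet assumes the Paddle headers shown in this diff):

// Illustrative only: reusing the new dispatcher with a different binary functor.
template <typename T>
struct CudaMulFunctorSketch {
  inline HOSTDEVICE T operator()(const T* args) const { return args[0] * args[1]; }
};

template <typename T>
void LaunchMulSketch(const platform::CUDADeviceContext& cuda_ctx,
                     const framework::Tensor* x, const framework::Tensor* y,
                     framework::Tensor* z, int axis) {
  std::vector<const framework::Tensor*> ins = {x, y};
  std::vector<framework::Tensor*> outs = {z};
  // Same-shaped inputs take LaunchSameDimsElementwiseCudaKernel; otherwise the
  // broadcast path runs, mirroring ElementwiseAddKernel<CUDADeviceContext, T>.
  LaunchElementwiseCudaKernel<ElementwiseType::kBinary, T, T>(
      cuda_ctx, ins, &outs, axis, CudaMulFunctorSketch<T>());
}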
@@ -15,8 +15,7 @@ limitations under the License. */
 #include "paddle/fluid/framework/tensor.h"
 #include "paddle/fluid/platform/device_context.h"
-#include "paddle/fluid/platform/enforce.h"
-#include "paddle/fluid/platform/float16.h"
+#include "paddle/fluid/platform/fast_divmod.h"

 #ifdef __HIPCC__
 #define ELEMENTWISE_BLOCK_SIZE 256
@@ -29,11 +28,6 @@ namespace operators {
 enum ElementwiseType { kUnary = 1, kBinary = 2 };

-template <typename T, int Size>
-struct alignas(sizeof(T) * Size) CudaAlignedVector {
-  T val[Size];
-};
-
 template <typename T>
 int GetVectorizedSizeImpl(const T *pointer) {
   uint64_t address = reinterpret_cast<uint64_t>(pointer);
@@ -181,7 +175,7 @@ __global__ void ScalarKernel(const InT *__restrict__ in0,
 }

 template <ElementwiseType ET, typename InT, typename OutT, typename Functor>
-void LaunchElementwiseCudaKernel(
+void LaunchSameDimsElementwiseCudaKernel(
     const platform::CUDADeviceContext &ctx,
     const std::vector<const framework::Tensor *> &ins,
     std::vector<framework::Tensor *> *outs, Functor func) {
...
@@ -14,13 +14,19 @@
 #pragma once

-#include "paddle/fluid/operators/elementwise/elementwise_op_impl.cu.h"
+#include <cstdint>
+#include "paddle/fluid/platform/hostdevice.h"

 #define INT_BITS 32

 namespace paddle {
 namespace operators {

+template <typename T, int Size>
+struct alignas(sizeof(T) * Size) CudaAlignedVector {
+  T val[Size];
+};
+
 struct FastDivMod {
   // 1st value represents the result of input number divides by recorded divisor
   // 2nd value represents the result of input number modulo by recorded divisor
...
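CudaAlignedVector, moved into this header by the commit, is what the vectorized load/store path above relies on: the alignas(sizeof(T) * Size) packet lets a plain T* be reinterpreted as a pointer to Size-element packets, exactly as BroadcastArgsWarpper does with vec_in_data and vec_out_data. A standalone illustration of the same idea (plain C++, not Paddle code):

// Standalone illustration of the aligned-packet trick behind CudaAlignedVector.
#include <cstdio>

template <typename T, int Size>
struct alignas(sizeof(T) * Size) AlignedVector {
  T val[Size];
};

int main() {
  // 16-byte aligned buffer of 8 floats; one AlignedVector<float, 4> spans 4 of them.
  alignas(16) float data[8] = {0, 1, 2, 3, 4, 5, 6, 7};
  using Vec4 = AlignedVector<float, 4>;
  const Vec4* packets = reinterpret_cast<const Vec4*>(data);
  // packets[1] covers elements 4..7 and can be moved with a single 128-bit load.
  printf("%.0f %.0f\n", packets[1].val[0], packets[1].val[3]);
  return 0;
}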