Unverified commit f9ea2301 authored by Bo Zhang, committed by GitHub

[cherry-pick 2.5] Broadcast && Dropout_nd Performance Optimization into Release/2.5 (#53623)

* Support different dtypes of inputs for broadcast for dropout optimization  (#52093)

* change judgement for DropoutGradGPUKernelDriver

* add UnrollerWithoutVecSize; LoadData to be refined after this

* pass unittest

* use same unroller with XPU

* BroadcastWithInt64Index

* BroadcastDataLoader template partial specialization

* fix compile errors in ROCm

* PR comment

* dropout_nd_optimization (#51479)

* with printf

* add DropOutNdForwardKernel

* PR comment

* Dropout optimize & clean broadcast inT and ElementwiseType (#52969)

* change judgement for DropoutGradGPUKernelDriver

* add UnrollerWithoutVecSize; LoadData to be refined after this

* pass unittest

* use same unroller with XPU

* BroadcastWithInt64Index

* BroadcastDataLoader template partial specialization

* fix compile errors in ROCm

* clean ElementwiseT and InT for BroadcastKernel

* default axis and clean inT

* remove redundant fast divmod computation

* optimize drop_nd & drop_nd_grad

* optimize BroadcastDataLoader bf16 fp16

* rm InT etc. after merge develop

* delete constexpr for windows ci

* fix conflict

* fix conflict with develop

* fix conflict

* new clean

* clean

* Fix xpu2 kp compile error (#53548)

* fix conflict

* conflict
Parent fecea4c5
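
The recurring change in the diff below is an interface simplification: BroadcastKernel, LaunchElementwiseCudaKernel, and ElementwiseCompute drop the ElementwiseType and input-dtype template parameters (arity and input types are now deduced from the functor via FunctionTraits), and `axis` becomes a trailing argument that defaults to -1. A minimal before/after sketch of a typical call site, taken from the hunks below; `dev_ctx_`, `ins`, and `outs` are the usual call-site locals:

// before
phi::funcs::BroadcastKernel<phi::ElementwiseType::kBinary, T, T>(
    dev_ctx_, ins, &outs, -1, phi::funcs::AddFunctor<T>());
// after: dtype/arity deduced from the functor, axis defaults to -1
phi::funcs::BroadcastKernel<T>(
    dev_ctx_, ins, &outs, phi::funcs::AddFunctor<T>());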
......@@ -19,17 +19,13 @@
namespace paddle {
namespace operators {
template <ElementwiseType ET,
typename InT,
typename OutT,
typename Functor,
int NumOuts = 1>
template <typename OutT, typename Functor, int NumOuts = 1>
void LaunchElementwiseCudaKernel(
const KPDevice &ctx,
const std::vector<const phi::DenseTensor *> &ins,
std::vector<phi::DenseTensor *> *outs,
int axis,
Functor func) {
Functor func,
int axis = -1) {
std::vector<const phi::DenseTensor *> pt_inputs;
std::vector<phi::DenseTensor *> pt_outputs;
// TODO(YuanRisheng) *_tmp for cache DenseTensor, because the temporary
......@@ -53,8 +49,8 @@ void LaunchElementwiseCudaKernel(
for (int i = 0; i < pt_outputs_tmp.size(); i++) {
pt_outputs.push_back(pt_outputs_tmp[i].get());
}
phi::funcs::BroadcastKernel<ET, InT, OutT, Functor, NumOuts>(
ctx, pt_inputs, &pt_outputs, axis, func);
phi::funcs::BroadcastKernel<OutT, Functor, NumOuts>(
ctx, pt_inputs, &pt_outputs, func, axis);
}
} // namespace operators
......
......@@ -188,7 +188,7 @@ void ElementwiseComputeEx(const framework::ExecutionContext &ctx,
z->mutable_data<OutType>(ctx.GetPlace());
const auto &dev_ctx = ctx.template device_context<DeviceContext>();
phi::funcs::ElementwiseCompute<Functor, T, OutType>(
dev_ctx, *x, *y, axis, func, z);
dev_ctx, *x, *y, func, z, axis);
}
// FusedElemwiseAndAct
......@@ -1596,7 +1596,7 @@ static inline std::vector<int> GetReduceDim(const framework::DDim &in,
#if defined(__NVCC__) || defined(__HIPCC__)
template <ElementwiseType ET, typename T, typename Functor>
template <typename T, typename Functor>
void GetGradXAndYOut(const phi::GPUContext &dev_ctx,
const platform::Place &place,
int axis,
......@@ -1605,11 +1605,11 @@ void GetGradXAndYOut(const phi::GPUContext &dev_ctx,
phi::DenseTensor *dx,
phi::DenseTensor *dy,
Functor func) {
phi::GetGradXAndYOut<ET, T, Functor>(
phi::GetGradXAndYOut<T, Functor>(
dev_ctx, place, axis, ins, *dout, dx, dy, func);
}
template <ElementwiseType ET, typename T, typename Functor>
template <typename T, typename Functor>
void GetGradXOrYOut(const phi::GPUContext &dev_ctx,
const platform::Place &place,
int axis,
......@@ -1617,8 +1617,7 @@ void GetGradXOrYOut(const phi::GPUContext &dev_ctx,
const phi::DenseTensor *dout,
phi::DenseTensor *dxy,
Functor func) {
phi::GetGradXOrYOut<ET, T, Functor>(
dev_ctx, place, axis, ins, *dout, dxy, func);
phi::GetGradXOrYOut<T, Functor>(dev_ctx, place, axis, ins, *dout, dxy, func);
}
#endif
......
......@@ -23,8 +23,6 @@ limitations under the License. */
namespace paddle {
namespace operators {
using ElementwiseType = phi::ElementwiseType;
template <typename OutT, typename Functor, int NumOuts = 1>
void LaunchSameDimsElementwiseCudaKernel(
const KPDevice &ctx,
......
......@@ -109,8 +109,8 @@ class AttnMatMul {
// bias_out = output + bias
std::vector<const phi::DenseTensor*> ins = {output, bias};
std::vector<phi::DenseTensor*> outs = {bias_out};
phi::funcs::BroadcastKernel<phi::ElementwiseType::kBinary, T, T>(
dev_ctx_, ins, &outs, -1, phi::funcs::AddFunctor<T>());
phi::funcs::BroadcastKernel<T>(
dev_ctx_, ins, &outs, phi::funcs::AddFunctor<T>());
}
}
......
......@@ -85,8 +85,8 @@ class AttnMatmulINT8 {
// bias_out = output + bias
std::vector<const phi::DenseTensor*> ins = {output, bias};
std::vector<phi::DenseTensor*> outs = {bias_out};
phi::funcs::BroadcastKernel<phi::ElementwiseType::kBinary, T, T>(
dev_ctx_, ins, &outs, -1, phi::funcs::AddFunctor<T>());
phi::funcs::BroadcastKernel<T>(
dev_ctx_, ins, &outs, phi::funcs::AddFunctor<T>());
PADDLE_ENFORCE_EQ(cudaGetLastError(),
cudaSuccess,
platform::errors::Fatal(
......@@ -139,8 +139,8 @@ class AttnMatmulINT8 {
// bias_out = output + bias
std::vector<const phi::DenseTensor*> ins = {output, bias};
std::vector<phi::DenseTensor*> outs = {bias_out};
phi::funcs::BroadcastKernel<phi::ElementwiseType::kBinary, T, T>(
dev_ctx_, ins, &outs, -1, phi::funcs::AddFunctor<T>());
phi::funcs::BroadcastKernel<T>(
dev_ctx_, ins, &outs, phi::funcs::AddFunctor<T>());
PADDLE_ENFORCE_EQ(cudaGetLastError(),
cudaSuccess,
platform::errors::Fatal(
......
......@@ -255,12 +255,11 @@ class FMHARef {
ins.emplace_back(src_mask_tensor);
outs.emplace_back(src_mask_out_tensor);
int elewise_add_axis = -1;
phi::funcs::BroadcastKernel<phi::ElementwiseType::kBinary, T, T>(
dev_ctx_,
phi::funcs::BroadcastKernel<T>(dev_ctx_,
ins,
&outs,
elewise_add_axis,
phi::funcs::AddFunctor<T>());
phi::funcs::AddFunctor<T>(),
elewise_add_axis);
phi::SoftmaxForwardCUDAKernelDriver<T>(
dev_ctx_, *src_mask_out_tensor, softmax_axis, softmax_out_tensor);
......@@ -432,12 +431,11 @@ class FMHARef {
ins.emplace_back(src_mask_tensor);
outs.emplace_back(src_mask_out_tensor);
int elewise_add_axis = -1;
phi::funcs::BroadcastKernel<phi::ElementwiseType::kBinary, T, T>(
dev_ctx_,
phi::funcs::BroadcastKernel<T>(dev_ctx_,
ins,
&outs,
elewise_add_axis,
phi::funcs::AddFunctor<T>());
phi::funcs::AddFunctor<T>(),
elewise_add_axis);
phi::SoftmaxForwardCUDAKernelDriver<T>(
dev_ctx_, *src_mask_out_tensor, softmax_axis, softmax_out_tensor);
......
......@@ -689,13 +689,13 @@ class FMHAGateRef {
std::vector<const phi::DenseTensor*> ins = {
qk_out, src_mask, nonbatched_bias};
std::vector<phi::DenseTensor*> outs = {qk_out};
phi::funcs::BroadcastKernel<phi::ElementwiseType::kTernary, T, T>(
dev_ctx_, ins, &outs, -1, TernaryAddFunctor<T>());
phi::funcs::BroadcastKernel<T>(
dev_ctx_, ins, &outs, TernaryAddFunctor<T>());
} else {
std::vector<const phi::DenseTensor*> ins = {qk_out, src_mask};
std::vector<phi::DenseTensor*> outs = {qk_out};
phi::funcs::BroadcastKernel<phi::ElementwiseType::kBinary, T, T>(
dev_ctx_, ins, &outs, -1, phi::funcs::AddFunctor<T>());
phi::funcs::BroadcastKernel<T>(
dev_ctx_, ins, &outs, phi::funcs::AddFunctor<T>());
}
phi::SoftmaxForwardCUDAKernelDriver<T>(dev_ctx_, *qk_out, -1, softmax_out);
}
......
......@@ -141,8 +141,7 @@ class FusedTokenPruneOpCUDAKernel : public framework::OpKernel<T> {
ins.emplace_back(attn);
ins.emplace_back(mask);
outs.emplace_back(&attn_tmp);
LaunchElementwiseCudaKernel<ElementwiseType::kBinary, T, T>(
dev_ctx, ins, &outs, -1, AttnMaskFunctor<T>());
LaunchElementwiseCudaKernel<T>(dev_ctx, ins, &outs, AttnMaskFunctor<T>());
// 2. Reduce sum
const std::vector<int64_t> reduce_dims{1, 2};
......
......@@ -836,8 +836,7 @@ class ReduceCudaGradKernel : public framework::OpKernel<T> {
}
using MPType = typename kps::details::MPTypeTrait<T>::Type;
phi::ReduceGrad<T, TransformOp<T, MPType>>(
dev_ctx,
phi::ReduceGrad<TransformOp<T, MPType>>(dev_ctx,
pt_d_out.get(),
pt_d_x.get(),
pt_out_dtype,
......
......@@ -31,8 +31,8 @@ namespace phi {
const DenseTensor& y, \
DenseTensor* out) { \
funcs::Bitwise##op_type##Functor<T> func; \
funcs::ElementwiseCompute<funcs::Bitwise##op_type##Functor<T>, T, T>( \
dev_ctx, x, y, -1, func, out); \
funcs::ElementwiseCompute<funcs::Bitwise##op_type##Functor<T>, T>( \
dev_ctx, x, y, func, out); \
}
DEFINE_BITWISE_KERNEL(And)
......
......@@ -33,10 +33,10 @@ inline void CompareKernelImpl(const Context& ctx,
ctx.template Alloc<bool>(out);
if (x.dims().size() >= y.dims().size()) {
funcs::ElementwiseCompute<Functor, T, bool>(
ctx, x, y, axis, Functor(), out);
ctx, x, y, Functor(), out, axis);
} else {
funcs::ElementwiseCompute<InverseFunctor, T, bool>(
ctx, x, y, axis, InverseFunctor(), out);
ctx, x, y, InverseFunctor(), out, axis);
}
}
......@@ -59,7 +59,7 @@ inline void CompareAllKernelImpl(const Context& ctx,
tmp_data[0] = Functor()(x.data<T>()[0], y.data<T>()[0]);
} else {
funcs::ElementwiseCompute<Functor, T, bool>(
ctx, x, y, 0, Functor(), &tmp);
ctx, x, y, Functor(), &tmp, 0);
}
auto tmp_flat = EigenVector<bool>::Flatten(tmp);
auto out_es = EigenScalar<bool>::From(*out);
......
......@@ -91,8 +91,8 @@ struct DirichletSampler<CPUContext, T> {
true,
false);
funcs::ElementwiseCompute<funcs::DivideFunctor<T>, T, T>(
dev_ctx, gamma_samples, gamma_sum, -1, funcs::DivideFunctor<T>(), out);
funcs::ElementwiseCompute<funcs::DivideFunctor<T>, T>(
dev_ctx, gamma_samples, gamma_sum, funcs::DivideFunctor<T>(), out);
}
};
......
......@@ -38,10 +38,10 @@ void DivideRawKernel(const Context& dev_ctx,
auto y_dims = y.dims();
if (x_dims.size() >= y_dims.size()) {
funcs::ElementwiseCompute<funcs::DivideFunctor<T>, T>(
dev_ctx, x, y, axis, funcs::DivideFunctor<T>(), out);
dev_ctx, x, y, funcs::DivideFunctor<T>(), out, axis);
} else {
funcs::ElementwiseCompute<funcs::InverseDivideFunctor<T>, T>(
dev_ctx, x, y, axis, funcs::InverseDivideFunctor<T>(), out);
dev_ctx, x, y, funcs::InverseDivideFunctor<T>(), out, axis);
}
}
}
......
......@@ -30,7 +30,7 @@ void MaximumRawKernel(const Context& dev_ctx,
// allocate memory for out
dev_ctx.template Alloc<T>(out);
funcs::ElementwiseCompute<funcs::MaximumFunctor<T>, T>(
dev_ctx, x, y, axis, funcs::MaximumFunctor<T>(), out);
dev_ctx, x, y, funcs::MaximumFunctor<T>(), out, axis);
}
template <typename T, typename Context>
......@@ -42,7 +42,7 @@ void MinimumRawKernel(const Context& dev_ctx,
// allocate memory for out
dev_ctx.template Alloc<T>(out);
funcs::ElementwiseCompute<funcs::MinimumFunctor<T>, T>(
dev_ctx, x, y, axis, funcs::MinimumFunctor<T>(), out);
dev_ctx, x, y, funcs::MinimumFunctor<T>(), out, axis);
}
template <typename T, typename Context>
......@@ -57,10 +57,10 @@ void RemainderRawKernel(const Context& dev_ctx,
auto y_dims = y.dims();
if (x_dims.size() >= y_dims.size()) {
funcs::ElementwiseCompute<funcs::RemainderFunctor<T>, T>(
dev_ctx, x, y, axis, funcs::RemainderFunctor<T>(), out);
dev_ctx, x, y, funcs::RemainderFunctor<T>(), out, axis);
} else {
funcs::ElementwiseCompute<funcs::InverseRemainderFunctor<T>, T>(
dev_ctx, x, y, axis, funcs::InverseRemainderFunctor<T>(), out);
dev_ctx, x, y, funcs::InverseRemainderFunctor<T>(), out, axis);
}
}
......@@ -76,10 +76,10 @@ void FloorDivideRawKernel(const Context& dev_ctx,
auto y_dims = y.dims();
if (x_dims.size() >= y_dims.size()) {
funcs::ElementwiseCompute<funcs::FloorDivideFunctor<T>, T>(
dev_ctx, x, y, axis, funcs::FloorDivideFunctor<T>(), out);
dev_ctx, x, y, funcs::FloorDivideFunctor<T>(), out, axis);
} else {
funcs::ElementwiseCompute<funcs::InverseFloorDivideFunctor<T>, T>(
dev_ctx, x, y, axis, funcs::InverseFloorDivideFunctor<T>(), out);
dev_ctx, x, y, funcs::InverseFloorDivideFunctor<T>(), out, axis);
}
}
......@@ -95,10 +95,10 @@ void ElementwisePowRawKernel(const Context& dev_ctx,
auto y_dims = y.dims();
if (x_dims.size() >= y_dims.size()) {
funcs::ElementwiseCompute<funcs::ElementwisePowFunctor<T>, T>(
dev_ctx, x, y, axis, funcs::ElementwisePowFunctor<T>(), out);
dev_ctx, x, y, funcs::ElementwisePowFunctor<T>(), out, axis);
} else {
funcs::ElementwiseCompute<funcs::ElementwiseInversePowFunctor<T>, T>(
dev_ctx, x, y, axis, funcs::ElementwiseInversePowFunctor<T>(), out);
dev_ctx, x, y, funcs::ElementwiseInversePowFunctor<T>(), out, axis);
}
}
......@@ -110,7 +110,7 @@ void HeavisideKernel(const Context& dev_ctx,
// allocate memory for out
dev_ctx.template Alloc<T>(out);
funcs::ElementwiseCompute<funcs::ElementwiseHeavisideFunctor<T>, T>(
dev_ctx, x, y, -1, funcs::ElementwiseHeavisideFunctor<T>(), out);
dev_ctx, x, y, funcs::ElementwiseHeavisideFunctor<T>(), out);
}
} // namespace phi
......
......@@ -68,20 +68,15 @@ void LayerNormGradKernel(const Context& dev_ctx,
temp_norm.Resize(matrix_shape);
dev_ctx.template Alloc<T>(&temp_norm);
// get x_norm
phi::funcs::ElementwiseCompute<funcs::SubtractFunctor<T>, T, T>(
dev_ctx,
x_tmp,
mean,
/*axis*/ 0,
funcs::SubtractFunctor<T>(),
&temp_norm);
phi::funcs::ElementwiseCompute<funcs::DivAndSqrtFunctor<T>, T, T>(
phi::funcs::ElementwiseCompute<funcs::SubtractFunctor<T>, T>(
dev_ctx, x_tmp, mean, funcs::SubtractFunctor<T>(), &temp_norm, 0);
phi::funcs::ElementwiseCompute<funcs::DivAndSqrtFunctor<T>, T>(
dev_ctx,
temp_norm,
variance,
/*axis*/ 0,
funcs::DivAndSqrtFunctor<T>(static_cast<T>(epsilon)),
&temp_norm);
&temp_norm,
0);
}
if (d_bias) {
......@@ -90,8 +85,8 @@ void LayerNormGradKernel(const Context& dev_ctx,
}
if (d_scale) {
dev_ctx.template Alloc<T>(d_scale);
phi::funcs::ElementwiseCompute<funcs::MultiplyFunctor<T>, T, T>(
dev_ctx, temp_norm, d_y, 0, funcs::MultiplyFunctor<T>(), &temp);
phi::funcs::ElementwiseCompute<funcs::MultiplyFunctor<T>, T>(
dev_ctx, temp_norm, d_y, funcs::MultiplyFunctor<T>(), &temp, 0);
colwise_sum(dev_ctx, temp, d_scale);
}
......@@ -107,70 +102,45 @@ void LayerNormGradKernel(const Context& dev_ctx,
if (d_scale) {
// dy_dx
phi::funcs::ElementwiseCompute<funcs::MultiplyFunctor<T>, T, T>(
dev_ctx, d_y, *scale, /*axis*/ 1, funcs::MultiplyFunctor<T>(), &temp);
phi::funcs::ElementwiseCompute<funcs::MultiplyFunctor<T>, T>(
dev_ctx, d_y, *scale, funcs::MultiplyFunctor<T>(), &temp, 1);
phi::Copy<Context>(dev_ctx, temp, dev_ctx.GetPlace(), false, d_x);
// dy_dmean_dx
row_mean(dev_ctx, temp, &temp_vec);
phi::funcs::ElementwiseCompute<funcs::SubtractFunctor<T>, T, T>(
dev_ctx,
*d_x,
temp_vec,
/*axis*/ 0,
funcs::SubtractFunctor<T>(),
d_x);
phi::funcs::ElementwiseCompute<funcs::SubtractFunctor<T>, T>(
dev_ctx, *d_x, temp_vec, funcs::SubtractFunctor<T>(), d_x, 0);
// dy_var_dx
phi::funcs::ElementwiseCompute<funcs::MultiplyFunctor<T>, T, T>(
dev_ctx,
temp,
temp_norm,
/*axis*/ 0,
funcs::MultiplyFunctor<T>(),
&temp);
phi::funcs::ElementwiseCompute<funcs::MultiplyFunctor<T>, T>(
dev_ctx, temp, temp_norm, funcs::MultiplyFunctor<T>(), &temp, 0);
} else {
// dy_dx
phi::Copy<Context>(dev_ctx, d_y, dev_ctx.GetPlace(), false, d_x);
// dy_dmean_dx
row_mean(dev_ctx, d_y, &temp_vec);
phi::funcs::ElementwiseCompute<funcs::SubtractFunctor<T>, T, T>(
dev_ctx,
*d_x,
temp_vec,
/*axis*/ 0,
funcs::SubtractFunctor<T>(),
d_x);
phi::funcs::ElementwiseCompute<funcs::SubtractFunctor<T>, T>(
dev_ctx, *d_x, temp_vec, funcs::SubtractFunctor<T>(), d_x, 0);
// dy_var_dx
phi::funcs::ElementwiseCompute<funcs::MultiplyFunctor<T>, T, T>(
dev_ctx,
d_y,
temp_norm,
/*axis*/ 0,
funcs::MultiplyFunctor<T>(),
&temp);
phi::funcs::ElementwiseCompute<funcs::MultiplyFunctor<T>, T>(
dev_ctx, d_y, temp_norm, funcs::MultiplyFunctor<T>(), &temp, 0);
}
// dy_var_dx
row_mean(dev_ctx, temp, &temp_vec);
phi::funcs::ElementwiseCompute<funcs::MultiplyFunctor<T>, T, T>(
dev_ctx,
temp_norm,
temp_vec,
/*axis*/ 0,
funcs::MultiplyFunctor<T>(),
&temp);
phi::funcs::ElementwiseCompute<funcs::SubtractFunctor<T>, T, T>(
dev_ctx, *d_x, temp, /*axis*/ 0, funcs::SubtractFunctor<T>(), d_x);
phi::funcs::ElementwiseCompute<funcs::DivAndSqrtFunctor<T>, T, T>(
phi::funcs::ElementwiseCompute<funcs::MultiplyFunctor<T>, T>(
dev_ctx, temp_norm, temp_vec, funcs::MultiplyFunctor<T>(), &temp, 0);
phi::funcs::ElementwiseCompute<funcs::SubtractFunctor<T>, T>(
dev_ctx, *d_x, temp, funcs::SubtractFunctor<T>(), d_x, 0);
phi::funcs::ElementwiseCompute<funcs::DivAndSqrtFunctor<T>, T>(
dev_ctx,
*d_x,
variance,
/*axis*/ 0,
funcs::DivAndSqrtFunctor<T>(static_cast<T>(epsilon)),
d_x);
d_x,
0);
d_x->Resize(dx_dim);
}
}
......
......@@ -67,30 +67,30 @@ void LayerNormKernel(const Context& dev_ctx,
// get variance
phi::funcs::ElementwiseCompute<funcs::SubAndSquareFunctor<T>, T, T>(
dev_ctx, x_tmp, *mean, 0, funcs::SubAndSquareFunctor<T>(), &out);
phi::funcs::ElementwiseCompute<funcs::SubAndSquareFunctor<T>, T>(
dev_ctx, x_tmp, *mean, funcs::SubAndSquareFunctor<T>(), &out, 0);
row_mean(dev_ctx, out, var);
// get x_norm
phi::funcs::ElementwiseCompute<funcs::SubtractFunctor<T>, T, T>(
dev_ctx, x_tmp, *mean, 0, funcs::SubtractFunctor<T>(), &out);
phi::funcs::ElementwiseCompute<funcs::SubtractFunctor<T>, T>(
dev_ctx, x_tmp, *mean, funcs::SubtractFunctor<T>(), &out, 0);
phi::funcs::ElementwiseCompute<funcs::DivAndSqrtFunctor<T>, T, T>(
phi::funcs::ElementwiseCompute<funcs::DivAndSqrtFunctor<T>, T>(
dev_ctx,
out,
*var,
0,
funcs::DivAndSqrtFunctor<T>(static_cast<T>(epsilon)),
&out);
&out,
0);
if (scale) {
phi::funcs::ElementwiseCompute<funcs::MultiplyFunctor<T>, T, T>(
dev_ctx, out, *scale, 1, funcs::MultiplyFunctor<T>(), &out);
phi::funcs::ElementwiseCompute<funcs::MultiplyFunctor<T>, T>(
dev_ctx, out, *scale, funcs::MultiplyFunctor<T>(), &out, 1);
}
if (bias) {
phi::funcs::ElementwiseCompute<funcs::AddFunctor<T>, T, T>(
dev_ctx, out, *bias, 1, funcs::AddFunctor<T>(), &out);
phi::funcs::ElementwiseCompute<funcs::AddFunctor<T>, T>(
dev_ctx, out, *bias, funcs::AddFunctor<T>(), &out, 1);
}
#else
PADDLE_ENFORCE_EQ(mean->numel(),
......
......@@ -32,7 +32,7 @@ namespace phi {
DenseTensor* out) { \
funcs::Logical##type##Functor<T> binary_func; \
funcs::ElementwiseCompute<funcs::Logical##type##Functor<T>, T, bool>( \
dev_ctx, x, y, -1, binary_func, out); \
dev_ctx, x, y, binary_func, out); \
}
DEFINE_LOGICAL_BINARY_KERNEL(And)
......
......@@ -132,11 +132,10 @@ void MatrixRankTolKernel(const Context& dev_ctx,
DenseTensor tol_tensor;
tol_tensor.Resize(dim_out);
dev_ctx.template Alloc<T>(&tol_tensor);
funcs::ElementwiseCompute<GreaterElementFunctor<T>, T, T>(
funcs::ElementwiseCompute<GreaterElementFunctor<T>, T>(
dev_ctx,
atol_tensor,
rtol_tensor,
-1,
GreaterElementFunctor<T>(),
&tol_tensor);
......@@ -151,17 +150,17 @@ void MatrixRankTolKernel(const Context& dev_ctx,
dev_ctx,
eigenvalue_tensor,
tol_tensor,
axis,
funcs::GreaterThanFunctor<T, int64_t>(),
&compare_result);
&compare_result,
axis);
} else {
funcs::ElementwiseCompute<funcs::LessThanFunctor<T, int64_t>, T, int>(
dev_ctx,
eigenvalue_tensor,
tol_tensor,
axis,
funcs::LessThanFunctor<T, int64_t>(),
&compare_result);
&compare_result,
axis);
}
phi::SumKernel<int64_t>(dev_ctx,
......
......@@ -31,20 +31,49 @@ namespace funcs {
enum BroadcastLoadType { kMixed = 1, kBroadcast = 2, kElementwise = 3 };
template <typename InT, typename OutT, int Arity>
template <int Index>
struct UseBroadcast {
template <typename ArgsT, typename Array1, typename Array2>
static HOSTDEVICE void Apply(
const std::vector<const DenseTensor *> &ins_tensor,
const ArgsT &args,
int64_t numel,
Array1 *ins_data,
Array2 *use_broadcast,
int *broadcast_num,
bool *all_elementwise) {
(*ins_data)[Index] = (const _ptr_ char *)(ins_tensor[Index]->data());
bool is_same_dim = ins_tensor[Index]->numel() == numel;
if (is_same_dim) {
(*use_broadcast)[Index] = false;
} else {
(*use_broadcast)[Index] = true;
(*broadcast_num)++;
}
*all_elementwise &= is_same_dim;
}
};
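// Illustrative note, not part of the diff: UseBroadcast is applied once per
// input through UnrollerWithoutVecSize<UseBroadcast, Arity>::step(...) below.
// For each input Index it records the raw data pointer as `const char *` and
// whether that input's numel matches the output (elementwise) or needs
// broadcasting, so the launcher can later pick kElementwise, kBroadcast, or
// kMixed.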
template <typename OutT, int Arity, typename Functor>
struct LoaderTypeClassifier {
public:
int64_t numel{0};
int vec_size{1};
int vec_size{4};
int broadcast_num{0};
bool all_elementwise{true};
phi::Array<int, Arity> use_broadcast;
phi::Array<const _ptr_ InT *__restrict__, Arity> ins_data;
phi::Array<bool, Arity> use_broadcast;
phi::Array<const _ptr_ char *__restrict__, Arity> ins_data;
LoaderTypeClassifier() {}
LoaderTypeClassifier(const std::vector<const DenseTensor *> &ins,
std::vector<DenseTensor *> *outs) {
using Traits = phi::funcs::FunctionTraits<Functor>;
using ArgsT = typename Traits::ArgsTuple;
ArgsT arg;
uint64_t out_addr = reinterpret_cast<uint64_t>((*outs)[0]->data<OutT>());
UnrollerWithoutVecSize<VecSizeGetter, Arity>::step(ins, arg, &vec_size);
for (auto i = 1; i < outs->size(); ++i) {
PADDLE_ENFORCE_EQ(
(*outs)[i]->dims(),
......@@ -56,165 +85,191 @@ struct LoaderTypeClassifier {
out_addr =
(out_addr | reinterpret_cast<uint64_t>((*outs)[i]->data<OutT>()));
}
int out_vec_size =
phi::GetVectorizedSize<OutT>(reinterpret_cast<OutT *>(out_addr));
uint64_t in_addr = static_cast<uint64_t>(0);
vec_size = std::min(
vec_size,
phi::GetVectorizedSize<OutT>(reinterpret_cast<OutT *>(out_addr)));
numel = (*outs)[0]->numel();
for (int i = 0; i < Arity; ++i) {
auto in_data = ins[i]->data<InT>();
ins_data[i] = (const _ptr_ InT *)(in_data);
bool is_same_dim = ins[i]->numel() == numel;
if (is_same_dim) {
use_broadcast[i] = false;
in_addr = (in_addr | reinterpret_cast<uint64_t>(in_data));
} else {
use_broadcast[i] = true;
broadcast_num++;
}
all_elementwise &= is_same_dim;
}
int in_vec_size = std::min(
4, phi::GetVectorizedSize<InT>(reinterpret_cast<InT *>(in_addr)));
vec_size = std::min(out_vec_size, in_vec_size);
UnrollerWithoutVecSize<UseBroadcast, Arity>::step(ins,
arg,
numel,
&ins_data,
&use_broadcast,
&broadcast_num,
&all_elementwise);
}
};
#ifndef PADDLE_WITH_XPU_KP
// Common broadcast/elementwise Loader.
template <typename T, int VecSize, int Arity, bool IsBoundary, int LoadType>
template <int Index, int VecSize, bool IsBoundary, int LoadType>
struct BroadcastDataLoader {
__device__ __forceinline__ void operator()(
T args[Arity][VecSize],
const phi::Array<const _ptr_ T *__restrict__, Arity> &ins,
const phi::Array<kps::details::BroadcastConfig, Arity> &configs,
const phi::Array<int, Arity> &use_broadcast,
template <typename Array1, typename Array2, typename Array3, typename ArgsT>
static __device__ __forceinline__ void Apply(const Array1 &ins,
ArgsT *args,
const Array2 &configs,
const Array3 &use_broadcast,
const int block_offset,
const int num,
const uint32_t numel) {
#pragma unroll
for (int i = 0; i < Arity; ++i) {
kps::Init<T, VecSize>(args[i], static_cast<T>(1.0f));
if (use_broadcast[i]) {
kps::ReadDataBc<T, VecSize, 1, IsBoundary>(
args[i], ins[i], block_offset, configs[i], numel, VecSize);
const uint32_t numel,
int read_lens) {
using Type = std::tuple_element_t<Index, ArgsT>;
#ifdef PADDLE_WITH_XPU_KP
kps::Init<Type, ArgsT, Index, VecSize>(
args, static_cast<Type>(1.0f), read_lens);
if (use_broadcast[Index]) {
kps::ReadDataBc<Type, VecSize, 1, ArgsT, Index, IsBoundary>(
args,
reinterpret_cast<const _ptr_ Type *>(ins[Index]),
block_offset,
configs[Index],
numel,
read_lens);
} else {
kps::ReadData<T, VecSize, 1, IsBoundary>(
args[i], ins[i] + block_offset, num, VecSize);
kps::ReadData<Type, VecSize, 1, ArgsT, Index, IsBoundary>(
args,
reinterpret_cast<const _ptr_ Type *>(ins[Index]) + block_offset,
num,
read_lens);
}
#else
kps::Init<Type, ArgsT, Index, VecSize>(args, static_cast<Type>(1.0f));
if (use_broadcast[Index]) {
kps::ReadDataBc<Type, VecSize, 1, ArgsT, Index, IsBoundary>(
args,
reinterpret_cast<const _ptr_ Type *>(ins[Index]),
block_offset,
configs[Index],
numel,
VecSize);
}
// NOTE: If an if...else... on the `use_broadcast[Index]` condition were used
// here, clang 12 would report errors when compiling for ROCm.
// Once the compiler is upgraded, if...else... may be used.
if (!use_broadcast[Index]) {
kps::ReadData<Type, VecSize, 1, ArgsT, Index, IsBoundary>(
args,
reinterpret_cast<const _ptr_ Type *>(ins[Index]) + block_offset,
num,
VecSize);
}
#endif
}
};
/* BroadcastDataLoaders Partial specialization */
#ifndef PADDLE_WITH_XPU_KP
// Scalar elementwise Loader with consideration of IsBoundary.
template <typename T, int VecSize, int Arity>
struct BroadcastDataLoader<T, VecSize, Arity, true, kElementwise> {
__device__ __forceinline__ void operator()(
T args[Arity][VecSize],
const phi::Array<const _ptr_ T *__restrict__, Arity> &ins,
const phi::Array<kps::details::BroadcastConfig, Arity> &configs,
const phi::Array<int, Arity> &use_broadcast,
template <int Index, int VecSize>
struct BroadcastDataLoader<Index, VecSize, true, kElementwise> {
template <typename Array1, typename Array2, typename Array3, typename ArgsT>
static __device__ __forceinline__ void Apply(const Array1 &ins,
ArgsT *args,
const Array2 &configs,
const Array3 &use_broadcast,
const int block_offset,
const int num,
const uint32_t numel) {
const uint32_t numel,
int read_lens) {
using Type = std::tuple_element_t<Index, ArgsT>;
int thread_offset = threadIdx.x * VecSize + block_offset;
#pragma unroll
for (int i = 0; i < Arity; ++i) {
#pragma unroll
for (int idx = 0; idx < VecSize; ++idx) {
args[i][idx] = static_cast<T>(1);
std::get<Index>(args[idx]) = static_cast<Type>(1);
int index = thread_offset + idx;
if (index < numel) {
args[i][idx] = ins[i][index];
}
std::get<Index>(args[idx]) =
reinterpret_cast<const _ptr_ Type *>(ins[Index])[index];
}
}
}
};
// Vectorized elementwise Loader without consideration of IsBoundary.
template <typename T, int VecSize, int Arity>
struct BroadcastDataLoader<T, VecSize, Arity, false, kElementwise> {
__device__ __forceinline__ void operator()(
T args[Arity][VecSize],
const phi::Array<const _ptr_ T *__restrict__, Arity> &ins,
const phi::Array<kps::details::BroadcastConfig, Arity> &configs,
const phi::Array<int, Arity> &use_broadcast,
template <int Index, int VecSize>
struct BroadcastDataLoader<Index, VecSize, false, kElementwise> {
template <typename Array1, typename Array2, typename Array3, typename ArgsT>
static __device__ __forceinline__ void Apply(const Array1 &ins,
ArgsT *args,
const Array2 &configs,
const Array3 &use_broadcast,
const int block_offset,
const int num,
const uint32_t numel) {
using VecType = phi::kps::details::VectorType<T, VecSize>;
VecType vec_temp[Arity];
const uint32_t numel,
int read_lens) {
using Type = std::tuple_element_t<Index, ArgsT>;
using VecType = phi::kps::details::VectorType<Type, VecSize>;
VecType vec_temp;
int thread_offset = threadIdx.x + blockIdx.x * blockDim.x;
#pragma unroll
for (int i = 0; i < Arity; ++i) {
const VecType *__restrict__ vec_input =
reinterpret_cast<const VecType *__restrict__>(ins[i]);
vec_temp[i] = vec_input[thread_offset];
reinterpret_cast<const VecType *__restrict__>(ins[Index]);
vec_temp = vec_input[thread_offset];
#pragma unroll
for (int idx = 0; idx < VecSize; ++idx) {
args[i][idx] = vec_temp[i].val[idx];
}
std::get<Index>(args[idx]) = vec_temp.val[idx];
}
}
};
// Common broadcast data loader.
template <typename T, int VecSize, int Arity, bool IsBoundary>
struct BroadcastDataLoader<T, VecSize, Arity, IsBoundary, kBroadcast> {
__device__ __forceinline__ void operator()(
T args[Arity][VecSize],
const phi::Array<const _ptr_ T *__restrict__, Arity> &ins,
const phi::Array<kps::details::BroadcastConfig, Arity> &configs,
const phi::Array<int, Arity> &use_broadcast,
const int block_offset,
const int num,
const uint32_t numel) {
uint32_t index_bc[Arity][VecSize];
#pragma unroll
for (int j = 0; j < Arity; ++j) {
template <int Index, int VecSize>
struct BroadcastDataInit {
template <typename ArgsT>
static __device__ __forceinline__ void Apply(ArgsT *args) {
using Type = std::tuple_element_t<Index, ArgsT>;
#pragma unroll
for (int k = 0; k < VecSize; ++k) {
index_bc[j][k] = 0;
args[j][k] = static_cast<T>(1);
std::get<Index>(args[k]) = static_cast<Type>(1);
}
}
};
uint32_t thread_offset = block_offset + threadIdx.x * VecSize;
template <int Index, int VecSize>
struct BroadcastDataSetter {
template <typename Array, typename ArgsT>
static __device__ __forceinline__ void Apply(const Array &ins,
ArgsT *args,
uint32_t index_bc[][VecSize]) {
using Type = std::tuple_element_t<Index, ArgsT>;
#pragma unroll
for (int k = 0; k < VecSize; ++k) {
uint32_t idx = thread_offset + k;
if (IsBoundary) {
if (idx == numel) break;
}
#pragma unroll
for (int i = 0; i < phi::DDim::kMaxRank; ++i) {
if (i == configs[0].rank) break;
auto fast_divmoder = configs[0].divmoders[i].Divmod(idx);
idx = fast_divmoder.val[0];
#pragma unroll
for (int j = 0; j < Arity; ++j) {
index_bc[j][k] += fast_divmoder.val[1] * configs[j].strides[i];
}
std::get<Index>(args[k]) =
reinterpret_cast<const _ptr_ Type *>(ins[Index])[index_bc[Index][k]];
}
}
};
#pragma unroll
for (int j = 0; j < Arity; ++j) {
#pragma unroll
for (int k = 0; k < VecSize; ++k) {
args[j][k] = ins[j][index_bc[j][k]];
}
}
#endif
// static broadcast unroller
template <template <int Index, int VecSize, bool IsBoundary, int LoadType>
typename Func,
bool IsBoundary,
int LoadType,
int VecSize,
int End,
int Begin = 0>
struct BcUnroller {
template <typename... Args>
static HOSTDEVICE inline void step(Args &&...args) {
Func<Begin, VecSize, IsBoundary, LoadType>::Apply(
std::forward<Args>(args)...);
BcUnroller<Func, IsBoundary, LoadType, VecSize, End, Begin + 1>::step(
args...);
}
};
#endif
template <typename InT,
typename OutT,
template <template <int Index, int VecSize, bool IsBoundary, int LoadType>
typename Func,
bool IsBoundary,
int LoadType,
int VecSize,
int End>
struct BcUnroller<Func, IsBoundary, LoadType, VecSize, End, End> {
template <typename... Args>
static HOSTDEVICE inline void step(Args &&...args) {}
};
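// Illustrative note, not part of the diff: BcUnroller is the broadcast-specific
// analogue of Unroller. BcUnroller<BroadcastDataLoader, IsBoundary, LoadType,
// VecSize, Arity>::step(...) expands at compile time into
//   BroadcastDataLoader<0, VecSize, IsBoundary, LoadType>::Apply(...);
//   BroadcastDataLoader<1, VecSize, IsBoundary, LoadType>::Apply(...);
//   ... up to Index = Arity - 1,
// which lets each input be loaded with its own element type taken from ArgsT.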
template <typename OutT,
typename Functor,
int Arity,
int NumOuts,
......@@ -222,59 +277,69 @@ template <typename InT,
bool IsBoundary,
int LoadType>
__device__ void VectorizedBroadcastKernelImpl(
const phi::Array<const _ptr_ InT *__restrict__, Arity> &ins,
const phi::Array<const _ptr_ char *__restrict__, Arity> &ins,
phi::Array<_ptr_ OutT *, NumOuts> outs,
const phi::Array<int, Arity> &use_broadcast,
const phi::Array<bool, Arity> &use_broadcast,
const uint32_t numel,
const phi::Array<kps::details::BroadcastConfig, Arity> &configs,
int num,
int block_offset,
int read_lens,
Functor func) {
__simd__ InT args[Arity][VecSize];
using Traits = phi::funcs::FunctionTraits<Functor>;
using ArgsT = typename Traits::ArgsTuple;
__simd__ ArgsT args[VecSize];
__simd__ ConditionalT<OutT, NumOuts> result[VecSize];
#ifdef PADDLE_WITH_XPU_KP
BcUnroller<BroadcastDataLoader, IsBoundary, LoadType, VecSize, Arity>::step(
ins, args, configs, use_broadcast, block_offset, num, numel, read_lens);
#else
if (LoadType == kBroadcast) {
uint32_t index_bc[Arity][VecSize] = {0};
Unroller<BroadcastDataInit, VecSize, Arity>::step(args);
uint32_t thread_offset = block_offset + threadIdx.x * VecSize;
#pragma unroll
for (int i = 0; i < Arity; ++i) {
kps::Init<InT, VecSize>(args[i], static_cast<InT>(1.0f), read_lens);
if (use_broadcast[i]) {
kps::ReadDataBc<InT, VecSize, 1, IsBoundary>(
args[i], ins[i], block_offset, configs[i], numel, read_lens);
} else {
kps::ReadData<InT, VecSize, 1, IsBoundary>(
args[i], ins[i] + block_offset, num, read_lens);
for (int k = 0; k < VecSize; ++k) {
uint32_t idx = thread_offset + k;
if (IsBoundary && idx == numel) break;
#pragma unroll
for (int i = 0; i < phi::DDim::kMaxRank; ++i) {
if (i == configs[0].rank) break;
auto fast_divmoder = configs[0].divmoders[i].Divmod(idx);
idx = fast_divmoder.val[0];
#pragma unroll
for (int j = 0; j < Arity; ++j) {
index_bc[j][k] += fast_divmoder.val[1] * configs[j].strides[i];
}
}
#else
BroadcastDataLoader<InT, VecSize, Arity, IsBoundary, LoadType>()(
args, ins, configs, use_broadcast, block_offset, num, numel);
}
Unroller<BroadcastDataSetter, VecSize, Arity>::step(ins, args, index_bc);
} else {
BcUnroller<BroadcastDataLoader, IsBoundary, LoadType, VecSize, Arity>::step(
ins, args, configs, use_broadcast, block_offset, num, numel, read_lens);
}
#endif
constexpr bool kCallElementwiseAny =
phi::funcs::FunctionTraits<Functor>::has_pointer_args;
phi::funcs::ElementwisePrimitiveCaller<InT,
ConditionalT<OutT, NumOuts>,
SameDimsElementwisePrimitiveCaller<ConditionalT<OutT, NumOuts>,
VecSize,
Functor,
Arity,
kCallElementwiseAny>()(
func, args, result, read_lens);
ArgsT,
Arity>()(func, args, result, read_lens);
phi::funcs::
ElementwiseWriteDataCallerBc<OutT, VecSize, IsBoundary, NumOuts>()(
outs, result, block_offset, num, read_lens);
}
template <typename Functor,
typename InT,
typename OutT,
int Arity,
int NumOuts,
int VecSize,
int LoadType>
__global__ void VectorizedBroadcastKernel(
phi::Array<const _ptr_ InT *__restrict__, Arity> ins,
phi::Array<const _ptr_ char *__restrict__, Arity> ins,
phi::Array<_ptr_ OutT *, NumOuts> outs,
phi::Array<int, Arity> use_broadcast,
phi::Array<bool, Arity> use_broadcast,
uint32_t numel,
phi::Array<kps::details::BroadcastConfig, Arity> configs,
int main_offset,
......@@ -285,8 +350,7 @@ __global__ void VectorizedBroadcastKernel(
int block_offset = BLOCK_ID_X * BLOCK_NUM_X * read_lens;
int stride = BLOCK_NUM_X * GRID_NUM_X * read_lens;
for (; block_offset < main_offset; block_offset += stride) {
VectorizedBroadcastKernelImpl<InT,
OutT,
VectorizedBroadcastKernelImpl<OutT,
Functor,
Arity,
NumOuts,
......@@ -304,8 +368,7 @@ __global__ void VectorizedBroadcastKernel(
}
int num = numel - block_offset;
if (num > 0) {
VectorizedBroadcastKernelImpl<InT,
OutT,
VectorizedBroadcastKernelImpl<OutT,
Functor,
Arity,
NumOuts,
......@@ -324,8 +387,7 @@ __global__ void VectorizedBroadcastKernel(
#else
int block_offset = BLOCK_ID_X * BLOCK_NUM_X * VecSize;
if (block_offset < main_offset) {
VectorizedBroadcastKernelImpl<InT,
OutT,
VectorizedBroadcastKernelImpl<OutT,
Functor,
Arity,
NumOuts,
......@@ -341,8 +403,7 @@ __global__ void VectorizedBroadcastKernel(
read_lens,
func);
} else {
VectorizedBroadcastKernelImpl<InT,
OutT,
VectorizedBroadcastKernelImpl<OutT,
Functor,
Arity,
NumOuts,
......@@ -361,19 +422,14 @@ __global__ void VectorizedBroadcastKernel(
#endif
}
template <typename InT,
typename OutT,
typename Func,
int Arity,
int NumOuts,
int VecSize>
template <typename OutT, typename Functor, int Arity, int NumOuts, int VecSize>
void LaunchBroadcastKernel(
const KPDevice &ctx,
const std::vector<const DenseTensor *> &ins,
std::vector<DenseTensor *> *outs,
Func func,
Functor func,
const phi::Array<kps::details::BroadcastConfig, Arity> &configs,
const LoaderTypeClassifier<InT, OutT, Arity> &loader_classifier) {
const LoaderTypeClassifier<OutT, Arity, Functor> &loader_classifier) {
phi::Array<_ptr_ OutT *, NumOuts> outs_data;
for (int i = 0; i < NumOuts; ++i) {
outs_data[i] = (_ptr_ OutT *)(ctx.Alloc<OutT>((*outs)[i]));
......@@ -388,7 +444,7 @@ void LaunchBroadcastKernel(
int main_offset = (numel / (read_lens * threads)) * read_lens * threads;
int tail_tid = numel % (read_lens * threads);
VectorizedBroadcastKernel<Func, InT, OutT, Arity, NumOuts, VecSize, false>
VectorizedBroadcastKernel<Functor, OutT, Arity, NumOuts, VecSize, false>
<<<blocks, threads, 0, stream>>>(loader_classifier.ins_data,
outs_data,
loader_classifier.use_broadcast,
......@@ -409,8 +465,7 @@ void LaunchBroadcastKernel(
int tail_tid = numel % (VecSize * threads);
if (loader_classifier.all_elementwise) {
VectorizedBroadcastKernel<Func,
InT,
VectorizedBroadcastKernel<Functor,
OutT,
Arity,
NumOuts,
......@@ -427,7 +482,7 @@ void LaunchBroadcastKernel(
func);
} else if (loader_classifier.broadcast_num > (Arity >> 1)) {
constexpr BroadcastLoadType type_ = (Arity > 1) ? kBroadcast : kMixed;
VectorizedBroadcastKernel<Func, InT, OutT, Arity, NumOuts, VecSize, type_>
VectorizedBroadcastKernel<Functor, OutT, Arity, NumOuts, VecSize, type_>
<<<blocks, threads, 0, stream>>>(loader_classifier.ins_data,
outs_data,
loader_classifier.use_broadcast,
......@@ -438,7 +493,7 @@ void LaunchBroadcastKernel(
VecSize,
func);
} else {
VectorizedBroadcastKernel<Func, InT, OutT, Arity, NumOuts, VecSize, kMixed>
VectorizedBroadcastKernel<Functor, OutT, Arity, NumOuts, VecSize, kMixed>
<<<blocks, threads, 0, stream>>>(loader_classifier.ins_data,
outs_data,
loader_classifier.use_broadcast,
......@@ -471,94 +526,49 @@ HOSTDEVICE static int64_t ConvertSrcIdxToDstIdx(
return dst_idx;
}
template <typename T, int VecSize, bool IsBoundary>
HOSTDEVICE static void ReadVecDataWithInt64Index(
const T *in,
template <int N>
struct MaxWithOne {
static constexpr auto kValue = (N >= 1 ? N : 1);
};
template <int Index, int VecSize>
struct ReadVecDataWithInt64Index {
template <typename Array1, typename Array2, typename Array3, typename ArgsT>
static __device__ __forceinline__ void Apply(
const Array1 &in,
ArgsT *args,
int64_t idx,
bool need_broadcast,
const Array2 &need_broadcast,
const phi::Array<int64_t, phi::DDim::kMaxRank + 1> &src_strides,
const phi::Array<int64_t, phi::DDim::kMaxRank + 1> &dst_strides,
const Array3 &dst_strides,
int rank,
int n,
phi::AlignedVector<T, VecSize> *out) {
if (IsBoundary) {
for (int i = 0; i < n; ++i) {
(*out)[i] =
in[ConvertSrcIdxToDstIdx(idx + i, src_strides, dst_strides, rank)];
bool is_boundary) {
using Type = std::tuple_element_t<Index, ArgsT>;
if (is_boundary) {
#pragma unroll
for (int i = 0; i < VecSize; ++i) {
std::get<Index>(args[i]) = in[Index][ConvertSrcIdxToDstIdx(
idx + i, src_strides, dst_strides[Index], rank)];
}
} else {
if (!need_broadcast) {
phi::Load<T, VecSize>(in + idx, out);
if (!need_broadcast[Index]) {
kps::ReadData<Type, VecSize, 1, ArgsT, Index, false>(
args, reinterpret_cast<const _ptr_ Type *>(in[Index]) + idx, 1);
} else {
#pragma unroll
for (int i = 0; i < VecSize; ++i) {
(*out)[i] =
in[ConvertSrcIdxToDstIdx(idx + i, src_strides, dst_strides, rank)];
std::get<Index>(args[i]) = in[Index][ConvertSrcIdxToDstIdx(
idx + i, src_strides, dst_strides[Index], rank)];
}
}
}
}
template <typename InT,
typename OutT,
typename Functor,
int VecSize,
int NumIns>
struct ApplyFunctorWithInt64IndexHelper {
HOSTDEVICE static OutT Run(const phi::AlignedVector<InT, VecSize> *ins_vec,
Functor functor,
int i);
};
template <typename InT, typename OutT, typename Functor, int VecSize>
struct ApplyFunctorWithInt64IndexHelper<InT, OutT, Functor, VecSize, 0> {
HOSTDEVICE static OutT Run(const phi::AlignedVector<InT, VecSize> *ins_vec,
Functor functor,
int i) {
return static_cast<OutT>(functor());
}
};
template <typename InT, typename OutT, typename Functor, int VecSize>
struct ApplyFunctorWithInt64IndexHelper<InT, OutT, Functor, VecSize, 1> {
HOSTDEVICE static OutT Run(const phi::AlignedVector<InT, VecSize> *ins_vec,
Functor functor,
int i) {
return static_cast<OutT>(functor(ins_vec[0][i]));
}
};
template <typename InT, typename OutT, typename Functor, int VecSize>
struct ApplyFunctorWithInt64IndexHelper<InT, OutT, Functor, VecSize, 2> {
HOSTDEVICE static OutT Run(const phi::AlignedVector<InT, VecSize> *ins_vec,
Functor functor,
int i) {
return static_cast<OutT>(functor(ins_vec[0][i], ins_vec[1][i]));
}
};
template <typename InT, typename OutT, typename Functor, int VecSize>
struct ApplyFunctorWithInt64IndexHelper<InT, OutT, Functor, VecSize, 3> {
HOSTDEVICE static OutT Run(const phi::AlignedVector<InT, VecSize> *ins_vec,
Functor functor,
int i) {
return static_cast<OutT>(
functor(ins_vec[0][i], ins_vec[1][i], ins_vec[2][i]));
}
};
template <int N>
struct MaxWithOne {
static constexpr auto kValue = (N >= 1 ? N : 1);
};
template <typename InT,
typename OutT,
typename Functor,
int VecSize,
int NumIns>
template <typename OutT, typename Functor, int VecSize, int NumIns>
__global__ void BroadcastKernelWithInt64Index(
phi::Array<const InT *, MaxWithOne<NumIns>::kValue> ins,
const phi::Array<const _ptr_ char *__restrict__, MaxWithOne<NumIns>::kValue>
&ins,
OutT *out,
phi::Array<phi::Array<int64_t, phi::DDim::kMaxRank + 1>,
MaxWithOne<NumIns>::kValue> ins_strides,
......@@ -572,70 +582,34 @@ __global__ void BroadcastKernelWithInt64Index(
int64_t stride = static_cast<int64_t>(blockDim.x) * gridDim.x * VecSize;
int64_t limit = numel - VecSize;
phi::Array<phi::AlignedVector<InT, VecSize>, MaxWithOne<NumIns>::kValue>
ins_vec;
using Traits = phi::funcs::FunctionTraits<Functor>;
using ArgsT = typename Traits::ArgsTuple;
ArgsT args[VecSize];
phi::AlignedVector<OutT, VecSize> out_vec;
for (; idx <= limit; idx += stride) {
#pragma unroll
for (int i = 0; i < NumIns; ++i) {
ReadVecDataWithInt64Index<InT, VecSize, false>(ins[i],
idx,
need_broadcasts[i],
out_strides,
ins_strides[i],
rank,
VecSize,
&ins_vec[i]);
}
Unroller<ReadVecDataWithInt64Index, VecSize, NumIns>::step(
ins, args, idx, need_broadcasts, out_strides, ins_strides, rank, false);
#pragma unroll
for (int i = 0; i < VecSize; ++i) {
out_vec[i] = ApplyFunctorWithInt64IndexHelper<InT,
OutT,
Functor,
VecSize,
NumIns>::Run(ins_vec.Get(),
functor,
i);
out_vec[i] = static_cast<OutT>(Apply(functor, args[i]));
}
phi::Store<OutT, VecSize>(out_vec, out + idx);
}
if (idx < numel) {
int remain = numel - idx; // remain is always less than VecSize, therefore
// `int` is enough here
#pragma unroll
for (int i = 0; i < NumIns; ++i) {
ReadVecDataWithInt64Index<InT, VecSize, true>(ins[i],
idx,
need_broadcasts[i],
out_strides,
ins_strides[i],
rank,
remain,
&ins_vec[i]);
}
Unroller<ReadVecDataWithInt64Index, VecSize, NumIns>::step(
ins, args, idx, need_broadcasts, out_strides, ins_strides, rank, true);
for (int i = 0; i < remain; ++i) {
out[idx + i] =
ApplyFunctorWithInt64IndexHelper<InT,
OutT,
Functor,
VecSize,
NumIns>::Run(ins_vec.Get(),
functor,
i);
out_vec[idx + i] = static_cast<OutT>(Apply(functor, args[i]));
}
}
}
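// Illustrative note, not part of the diff: with InT removed, each vector lane
// now holds one ArgsT tuple (one element per functor argument, each with its
// own type) filled by Unroller<ReadVecDataWithInt64Index, ...>::step, and the
// functor is invoked through Apply(functor, args[i]), an std::apply-style
// helper (see the detail namespace further down in this diff), replacing the
// per-arity ApplyFunctorWithInt64IndexHelper specializations deleted above.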
template <typename InT,
typename OutT,
typename Functor,
int Arity,
int NumOuts,
int VecSize>
template <typename OutT, typename Functor, int Arity, int NumOuts, int VecSize>
struct LaunchBroadcastKernelWithInt64IndexHelper {
static void Run(const KPDevice &ctx,
const std::vector<const DenseTensor *> &ins,
......@@ -647,9 +621,8 @@ struct LaunchBroadcastKernelWithInt64IndexHelper {
}
};
template <typename InT, typename OutT, typename Functor, int Arity, int VecSize>
struct LaunchBroadcastKernelWithInt64IndexHelper<InT,
OutT,
template <typename OutT, typename Functor, int Arity, int VecSize>
struct LaunchBroadcastKernelWithInt64IndexHelper<OutT,
Functor,
Arity,
/*NumOuts=*/1,
......@@ -659,10 +632,9 @@ struct LaunchBroadcastKernelWithInt64IndexHelper<InT,
std::vector<DenseTensor *> *outs,
int axis,
Functor functor) {
phi::Array<const InT *, MaxWithOne<Arity>::kValue> ins_ptrs;
for (int i = 0; i < Arity; ++i) {
ins_ptrs[i] = ins[i]->data<InT>();
}
phi::Array<const _ptr_ char *__restrict__, MaxWithOne<Arity>::kValue>
ins_ptrs;
UnrollerWithoutVecSize<InputSetter, Arity>::step(ins, &ins_ptrs);
auto *out_tensor = (*outs)[0];
auto *out_ptr = ctx.Alloc<OutT>(out_tensor);
......@@ -734,7 +706,7 @@ struct LaunchBroadcastKernelWithInt64IndexHelper<InT,
auto gpu_config =
phi::backends::gpu::GetGpuLaunchConfig1D(ctx, numel, VecSize);
BroadcastKernelWithInt64Index<InT, OutT, Functor, VecSize, Arity>
BroadcastKernelWithInt64Index<OutT, Functor, VecSize, Arity>
<<<gpu_config.block_per_grid,
gpu_config.thread_per_block,
0,
......@@ -843,58 +815,24 @@ struct LaunchBroadcastKernelWithInt64IndexHelper<InT,
};
#endif
template <ElementwiseType ET,
typename InT,
typename OutT,
typename Functor,
int NumOuts = 1>
template <typename OutT, typename Functor, int kArity, int NumOuts = 1>
void BroadcastKernelForDifferentVecSize(
const KPDevice &ctx,
const std::vector<const DenseTensor *> &ins,
std::vector<DenseTensor *> *outs,
int axis,
Functor func) {
using Traits = phi::funcs::FunctionTraits<Functor>;
const int kArity =
Traits::has_pointer_args ? static_cast<int>(ET) : Traits::arity;
PADDLE_ENFORCE_EQ(
ins.size(),
kArity,
phi::errors::InvalidArgument("The number of inputs is expected to be "
"equal to the "
"arity of functor. But received: the "
"number of inputs "
"is %d, the arity of functor is %d.",
ins.size(),
kArity));
PADDLE_ENFORCE_LE(
kArity,
3,
phi::errors::InvalidArgument("Currently only broadcast of ternary is "
"supported "
"and verified, but received %d.",
kArity));
PADDLE_ENFORCE_EQ(
outs->size(),
NumOuts,
phi::errors::InvalidArgument("Number of outputs shall equal to number "
"of functions, "
"but number of outputs is %d, of "
"functions is %d.",
outs->size(),
NumOuts));
#ifndef PADDLE_WITH_XPU_KP
constexpr bool kEnabledInt64IndexKernel = (NumOuts == 1 && kArity <= 3);
bool use_int64_index_kernel =
kEnabledInt64IndexKernel &&
(*outs)[0]->numel() >= std::numeric_limits<int32_t>::max();
if (use_int64_index_kernel) {
auto loader_classifier = LoaderTypeClassifier<InT, OutT, kArity>(ins, outs);
auto loader_classifier =
LoaderTypeClassifier<OutT, kArity, Functor>(ins, outs);
switch (loader_classifier.vec_size) {
case VecSizeL: {
LaunchBroadcastKernelWithInt64IndexHelper<InT,
OutT,
LaunchBroadcastKernelWithInt64IndexHelper<OutT,
Functor,
kArity,
NumOuts,
......@@ -906,8 +844,7 @@ void BroadcastKernelForDifferentVecSize(
break;
}
case VecSizeM: {
LaunchBroadcastKernelWithInt64IndexHelper<InT,
OutT,
LaunchBroadcastKernelWithInt64IndexHelper<OutT,
Functor,
kArity,
NumOuts,
......@@ -919,8 +856,7 @@ void BroadcastKernelForDifferentVecSize(
break;
}
case VecSizeS: {
LaunchBroadcastKernelWithInt64IndexHelper<InT,
OutT,
LaunchBroadcastKernelWithInt64IndexHelper<OutT,
Functor,
kArity,
NumOuts,
......@@ -949,7 +885,7 @@ void BroadcastKernelForDifferentVecSize(
phi::errors::InvalidArgument(
"XPU only support inputs is 2, but received %d", ins.size()));
auto loader_classifier = LoaderTypeClassifier<InT, OutT, kArity>();
auto loader_classifier = LoaderTypeClassifier<OutT, kArity, Functor>();
const auto dims_simplifier =
BroadcastDimsSimplifier(ins, (*outs)[0]->dims(), axis);
if (VLOG_IS_ON(6)) {
......@@ -968,7 +904,8 @@ void BroadcastKernelForDifferentVecSize(
bool is_optimize = configs[0].cmp_type != type;
int vec_size = is_optimize ? VecSizeL : VecSizeM;
#else
auto loader_classifier = LoaderTypeClassifier<InT, OutT, kArity>(ins, outs);
auto loader_classifier =
LoaderTypeClassifier<OutT, kArity, Functor>(ins, outs);
if (!loader_classifier.all_elementwise) {
const auto dims_simplifier =
BroadcastDimsSimplifier(ins, (*outs)[0]->dims(), axis);
......@@ -991,17 +928,17 @@ void BroadcastKernelForDifferentVecSize(
#endif
switch (loader_classifier.vec_size) {
case VecSizeL: {
LaunchBroadcastKernel<InT, OutT, Functor, kArity, NumOuts, VecSizeL>(
LaunchBroadcastKernel<OutT, Functor, kArity, NumOuts, VecSizeL>(
ctx, ins, outs, func, configs, loader_classifier);
break;
}
case VecSizeM: {
LaunchBroadcastKernel<InT, OutT, Functor, kArity, NumOuts, VecSizeM>(
LaunchBroadcastKernel<OutT, Functor, kArity, NumOuts, VecSizeM>(
ctx, ins, outs, func, configs, loader_classifier);
break;
}
case VecSizeS: {
LaunchBroadcastKernel<InT, OutT, Functor, kArity, NumOuts, VecSizeS>(
LaunchBroadcastKernel<OutT, Functor, kArity, NumOuts, VecSizeS>(
ctx, ins, outs, func, configs, loader_classifier);
break;
}
......@@ -1013,18 +950,36 @@ void BroadcastKernelForDifferentVecSize(
}
}
template <ElementwiseType ET,
typename InT,
typename OutT,
typename Functor,
int NumOuts = 1>
template <typename OutT, typename Functor, int NumOuts = 1>
void BroadcastKernel(const KPDevice &ctx,
const std::vector<const DenseTensor *> &ins,
std::vector<DenseTensor *> *outs,
int axis,
Functor func) {
Functor func,
int axis = -1) {
// When there are multiple inputs, the output's rank should equal the
// maximum rank of all inputs.
using Traits = phi::funcs::FunctionTraits<Functor>;
const int kArity = Traits::arity;
PADDLE_ENFORCE_EQ(
ins.size(),
kArity,
phi::errors::InvalidArgument("The number of inputs is expected to be "
"equal to the "
"arity of functor. But received: the "
"number of inputs "
"is %d, the arity of functor is %d.",
ins.size(),
kArity));
PADDLE_ENFORCE_EQ(
outs->size(),
NumOuts,
phi::errors::InvalidArgument("Number of outputs shall equal to number "
"of functions, "
"but number of outputs is %d, of "
"functions is %d.",
outs->size(),
NumOuts));
int max_rank = 0;
int min_rank = phi::DDim::kMaxRank;
for (auto *in : ins) {
......@@ -1037,7 +992,7 @@ void BroadcastKernel(const KPDevice &ctx,
max_rank = std::max(max_rank, (*outs)[0]->dims().size());
}
axis = axis == -1 ? max_rank - min_rank : axis;
BroadcastKernelForDifferentVecSize<ET, InT, OutT, Functor, NumOuts>(
BroadcastKernelForDifferentVecSize<OutT, Functor, kArity, NumOuts>(
ctx, ins, outs, axis, func);
}
......@@ -1045,14 +1000,14 @@ template <typename Functor, typename T, typename OutType = T>
void ElementwiseCompute(const GPUContext &dev_ctx,
const DenseTensor &x,
const DenseTensor &y,
int axis,
Functor func,
DenseTensor *z) {
DenseTensor *z,
int axis = -1) {
std::vector<const DenseTensor *> ins = {&x, &y};
std::vector<DenseTensor *> outs = {z};
dev_ctx.template Alloc<OutType>(z);
BroadcastKernel<ElementwiseType::kBinary, T, OutType, Functor, 1>(
dev_ctx, ins, &outs, axis, func);
BroadcastKernel<OutType, Functor, 1>(dev_ctx, ins, &outs, func, axis);
}
template <typename DeviceContext,
......@@ -1067,7 +1022,7 @@ void DefaultElementwiseOperator(const DeviceContext &dev_ctx,
auto x_dims = x.dims();
auto y_dims = y.dims();
dev_ctx.template Alloc<T>(z);
funcs::ElementwiseCompute<Functor, T>(dev_ctx, x, y, axis, Functor(), z);
funcs::ElementwiseCompute<Functor, T>(dev_ctx, x, y, Functor(), z, axis);
}
#else
......@@ -1085,10 +1040,10 @@ void DefaultElementwiseOperator(const DeviceContext &dev_ctx,
auto y_dims = y.dims();
dev_ctx.template Alloc<T>(z);
if (x_dims.size() >= y_dims.size()) {
funcs::ElementwiseCompute<Functor, T>(dev_ctx, x, y, axis, Functor(), z);
funcs::ElementwiseCompute<Functor, T>(dev_ctx, x, y, Functor(), z, axis);
} else {
funcs::ElementwiseCompute<InverseFunctor, T>(
dev_ctx, x, y, axis, InverseFunctor(), z);
dev_ctx, x, y, InverseFunctor(), z, axis);
}
}
#endif
......
......@@ -191,25 +191,19 @@ __global__ void VectorizedRandomGenerator(const size_t n,
}
template <typename T>
__global__ void DropOutNdForwardKernel(
const size_t n,
__global__ void VectorizedGeneratorMask(const size_t n,
uint64_t seed,
const float dropout_prob,
const T* src,
uint8_t* mask,
uint64_t increment,
size_t main_offset,
DstFunctor<T> dst_functor,
MaskFunctor<T> mask_functor,
T* y,
int64_t N,
kps::details::BroadcastConfig broadcast_config,
const uint64_t* seed_ptr) {
// Vectorized mask generation.
// kCount is 4 because curand_uniform4 is used.
if (seed_ptr) {
seed = seed_ptr[0];
}
if (seed_ptr) seed = seed_ptr[0];
constexpr int kCount = phi::funcs::uniform_distribution<float>::kReturnsCount;
size_t idx = static_cast<size_t>(BLOCK_ID_X * BLOCK_NUM_X);
......@@ -259,21 +253,6 @@ __global__ void DropOutNdForwardKernel(
kps::WriteData<uint8_t, kCount, 1, true>(
mask + fix, &mask_result[0], remainder);
}
// Broadcast mask data and do the elementwise operation with DstFunctor
CUDA_KERNEL_LOOP(i, N) {
uint32_t offset = 0u;
uint32_t idx = i;
// Use the (j < phi::DDim::kMaxRank) condition rather than
// (j < broadcast_config.rank) so that #pragma unroll can be applied
#pragma unroll
for (int j = 0; j < phi::DDim::kMaxRank; ++j) {
if (j == broadcast_config.rank) break;
auto fast_divmoder = broadcast_config.divmoders[j].Divmod(idx);
idx = fast_divmoder.val[0];
offset += broadcast_config.strides[j] * fast_divmoder.val[1];
}
y[i] = dst_functor(src[i], mask[offset]);
}
}
template <typename T, typename MT>
......@@ -347,18 +326,6 @@ void DropoutFwGPUKernelDriver(
size / (block_size * kVecSize) * (block_size * kVecSize);
if (is_dropout_nd) {
auto dst_functor =
DstFunctor<T>(1.0f - dropout_prob, upscale_in_train, x_numel);
std::vector<int64_t> out_dims =
std::move(phi::vectorize<int64_t>(x.dims()));
std::vector<int64_t> in_dims =
std::move(phi::vectorize<int64_t>(mask->dims()));
std::reverse(out_dims.begin(), out_dims.end());
std::reverse(in_dims.begin(), in_dims.end());
kps::details::BroadcastConfig broadcast_config(
out_dims, in_dims, x.dims().size());
auto mask_functor = MaskFunctor<T>(1.0f - dropout_prob);
bool copy_in_kernel = GetSeedDataAndIncrement(dev_ctx,
seed,
......@@ -371,20 +338,22 @@ void DropoutFwGPUKernelDriver(
const uint64_t* seed_ptr =
copy_in_kernel ? seed->data<uint64_t>() : nullptr;
DropOutNdForwardKernel<T>
VectorizedGeneratorMask<T>
<<<grid_size, block_size, 0, stream>>>(size,
seed_data,
dropout_prob,
x_data,
mask_data,
increment,
main_offset,
dst_functor,
mask_functor,
y_data,
y->numel(),
broadcast_config,
seed_ptr);
auto dst_functor =
DstFunctor<T>(1.0f - dropout_prob, upscale_in_train, x_numel);
std::vector<const phi::DenseTensor*> ins = {&x, mask};
std::vector<phi::DenseTensor*> outs = {y};
phi::funcs::BroadcastKernel<T>(dev_ctx, ins, &outs, dst_functor);
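// Illustrative note, not part of the diff: the dropout_nd forward path is now
// two launches -- VectorizedGeneratorMask only fills `mask`, and the masked,
// scaled output y = dst_functor(x, mask) is produced by the generic
// BroadcastKernel, replacing the hand-written CUDA_KERNEL_LOOP broadcast that
// the old DropOutNdForwardKernel carried (deleted above).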
} else {
bool copy_in_kernel = GetSeedDataAndIncrement(
dev_ctx, seed, is_fix_seed, seed_val, offset, &seed_data, &increment);
......@@ -458,43 +427,28 @@ void DropoutGradGPUKernelDriver(const phi::GPUContext& dev_ctx,
// y = factor * x
ScaleByDropoutFactor<T, MT>(dev_ctx, grad_y, grad_x, factor);
} else {
phi::DenseTensor broadcasted_mask;
if (is_dropout_nd) {
broadcasted_mask.Resize(grad_y.dims());
dev_ctx.template Alloc<uint8_t>(&broadcasted_mask);
std::vector<const phi::DenseTensor*> broadcast_ins = {&mask};
std::vector<phi::DenseTensor*> broadcast_outs = {&broadcasted_mask};
phi::funcs::BroadcastKernel<phi::ElementwiseType::kUnary,
uint8_t,
uint8_t>(dev_ctx,
broadcast_ins,
&broadcast_outs,
-1,
kps::IdentityFunctor<uint8_t>());
}
std::vector<const phi::DenseTensor*> ins = {
&grad_y, is_dropout_nd ? &broadcasted_mask : &mask};
std::vector<phi::DenseTensor*> outs = {grad_x};
if (upscale_in_train) {
if (dropout_prob == 1.0f) {
if (upscale_in_train && dropout_prob == 1.0f) {
#ifdef PADDLE_WITH_HIP
hipMemset(grad_x->data<T>(), 0, grad_x->numel() * sizeof(T));
#else
cudaMemset(grad_x->data<T>(), 0, grad_x->numel() * sizeof(T));
#endif
} else {
MT factor = static_cast<MT>(1.0f / (1.0f - dropout_prob));
phi::funcs::ElementwiseKernel<T>(
MT factor = upscale_in_train
? static_cast<MT>(1.0f / (1.0f - dropout_prob))
: static_cast<MT>(1.0f);
std::vector<const phi::DenseTensor*> ins = {&grad_y, &mask};
std::vector<phi::DenseTensor*> outs = {grad_x};
if (is_dropout_nd) {
phi::funcs::BroadcastKernel<T>(
dev_ctx, ins, &outs, CudaDropoutGradFunctor<T>(factor));
}
} else {
MT factor = static_cast<MT>(1.0f);
phi::funcs::ElementwiseKernel<T>(
dev_ctx, ins, &outs, CudaDropoutGradFunctor<T>(factor));
}
}
}
}
} // namespace funcs
......
......@@ -35,7 +35,6 @@ namespace kps = phi::kps;
namespace phi {
enum ElementwiseType { kUnary = 1, kBinary = 2, kTernary = 3, kAny = -1 };
/* Packs a scalar type T (float, int, etc.) into an Array<T, NumOuts> type
   to support the multiple-output feature of the elementwise system. */
template <class T, int Num>
......@@ -369,9 +368,9 @@ template <typename Functor, typename T, typename OutType = T>
void ElementwiseCompute(const CPUContext &dev_ctx,
const DenseTensor &x,
const DenseTensor &y,
int axis,
Functor func,
DenseTensor *z) {
DenseTensor *z,
int axis = -1) {
dev_ctx.Alloc<OutType>(z);
auto x_dims = x.dims();
auto y_dims = y.dims();
......@@ -508,10 +507,26 @@ struct Unroller<Func, VecSize, End, End> {
static HOSTDEVICE inline void step(Args &&...args) {}
};
// static unroller without VecSize for broadcast
template <template <int Index> typename Func, int End, int Begin = 0>
struct UnrollerWithoutVecSize {
template <typename... Args>
static HOSTDEVICE inline void step(Args &&...args) {
Func<Begin>::Apply(std::forward<Args>(args)...);
UnrollerWithoutVecSize<Func, End, Begin + 1>::step(args...);
}
};
template <template <int Index> typename Func, int End>
struct UnrollerWithoutVecSize<Func, End, End> {
template <typename... Args>
static HOSTDEVICE inline void step(Args &&...args) {}
};
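// Illustrative note, not part of the diff: this unroller differs from Unroller
// only in dropping the VecSize template parameter, so per-index functors that
// never needed it (InputSetter, VecSizeGetter, UseBroadcast) now take a single
// <int Index>. For End == 2,
//   UnrollerWithoutVecSize<InputSetter, 2>::step(ins, &ins_data);
// expands at compile time to
//   InputSetter<0>::Apply(ins, &ins_data);
//   InputSetter<1>::Apply(ins, &ins_data);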
template <int Index, int VecSize>
struct Loader {
template <typename Array, typename ArgsT>
static __device__ void Apply(const Array &in,
static __device__ __forceinline__ void Apply(const Array &in,
ArgsT *args,
kps::IndexType offset,
int num,
......@@ -536,7 +551,7 @@ struct Loader {
}
};
template <int Index, int VecSize>
template <int Index>
struct InputSetter {
template <typename Array>
static HOSTDEVICE void Apply(
......@@ -545,7 +560,7 @@ struct InputSetter {
}
};
template <int Index, int VecSize>
template <int Index>
struct VecSizeGetter {
template <typename ArgsT>
static HOSTDEVICE void Apply(const std::vector<const DenseTensor *> &ins,
......@@ -569,8 +584,7 @@ int GetVectorizedSizeForTensors(const std::vector<const DenseTensor *> &ins,
int vec_size = 4;
uint64_t addr = static_cast<uint64_t>(0);
ArgsT arg;
// The Arg VecSize=1 is to match the Unroller template.
Unroller<VecSizeGetter, 1, Arity>::step(ins, arg, &vec_size);
UnrollerWithoutVecSize<VecSizeGetter, Arity>::step(ins, arg, &vec_size);
for (auto iter = outs.begin(); iter != outs.end(); ++iter) {
addr = (addr | reinterpret_cast<uint64_t>((*iter)->data<OutT>()));
}
......@@ -580,73 +594,6 @@ int GetVectorizedSizeForTensors(const std::vector<const DenseTensor *> &ins,
return vec_size;
}
template <typename InT,
typename OutT,
int VecSize,
typename Functor,
int Arity,
bool CallElementwiseAny = false>
struct ElementwisePrimitiveCaller {
__device__ inline void operator()(Functor func,
InT (*args)[VecSize],
OutT *result,
int read_lens);
};
template <typename InT, typename OutT, int VecSize, typename Functor, int Arity>
struct ElementwisePrimitiveCaller<InT, OutT, VecSize, Functor, Arity, true> {
__device__ inline void operator()(Functor func,
InT (*args)[VecSize],
OutT *result,
int read_lens) {
kps::ElementwiseAny<InT, OutT, VecSize, 1, Arity, Functor>(
result, args, func);
}
};
template <typename InT, typename OutT, int VecSize, typename Functor>
struct ElementwisePrimitiveCaller<InT, OutT, VecSize, Functor, 0, false> {
__device__ inline void operator()(Functor func,
InT (*args)[VecSize],
OutT *result,
int read_lens) {
kps::ElementwiseConstant<InT, OutT, VecSize, 1, Functor>(result, func);
}
};
template <typename InT, typename OutT, int VecSize, typename Functor>
struct ElementwisePrimitiveCaller<InT, OutT, VecSize, Functor, 1, false> {
__device__ inline void operator()(Functor func,
InT (*args)[VecSize],
OutT *result,
int read_lens) {
kps::ElementwiseUnary<InT, OutT, VecSize, 1, Functor>(
result, args[0], func);
}
};
template <typename InT, typename OutT, int VecSize, typename Functor>
struct ElementwisePrimitiveCaller<InT, OutT, VecSize, Functor, 2, false> {
__device__ inline void operator()(Functor func,
InT (*args)[VecSize],
OutT *result,
int read_lens) {
kps::ElementwiseBinary<InT, OutT, VecSize, 1, Functor>(
result, args[0], args[1], func, read_lens);
}
};
template <typename InT, typename OutT, int VecSize, typename Functor>
struct ElementwisePrimitiveCaller<InT, OutT, VecSize, Functor, 3, false> {
__device__ inline void operator()(Functor func,
InT (*args)[VecSize],
OutT *result,
int read_lens) {
kps::ElementwiseTernary<InT, OutT, VecSize, 1, Functor>(
result, args[0], args[1], args[2], func);
}
};
namespace detail {
template <class F, class Tuple, std::size_t... Index>
// GCC/Clang need the decltype() return type
......@@ -802,7 +749,7 @@ void LaunchElementwiseCudaKernel(const KPDevice &ctx,
phi::Array<const _ptr_ char *__restrict__, Arity> ins_data;
phi::Array<_ptr_ OutT *, NumOuts> outs_data;
Unroller<InputSetter, VecSize, Arity>::step(ins, &ins_data);
UnrollerWithoutVecSize<InputSetter, Arity>::step(ins, &ins_data);
for (int i = 0; i < NumOuts; ++i) {
outs_data[i] = (_ptr_ OutT *)(ctx.Alloc<OutT>((*outs)[i]));
}
......
......@@ -112,8 +112,8 @@ class AttnMatMul {
// bias_out = output + bias
std::vector<const phi::DenseTensor*> ins = {output, bias};
std::vector<phi::DenseTensor*> outs = {bias_out};
phi::funcs::BroadcastKernel<phi::ElementwiseType::kBinary, T, T>(
dev_ctx_, ins, &outs, -1, phi::funcs::AddFunctor<T>());
phi::funcs::BroadcastKernel<T>(
dev_ctx_, ins, &outs, phi::funcs::AddFunctor<T>());
}
}
......
......@@ -258,12 +258,11 @@ class FMHARef {
ins.emplace_back(src_mask_tensor);
outs.emplace_back(src_mask_out_tensor);
int elewise_add_axis = -1;
phi::funcs::BroadcastKernel<phi::ElementwiseType::kBinary, T, T>(
dev_ctx_,
phi::funcs::BroadcastKernel<T>(dev_ctx_,
ins,
&outs,
elewise_add_axis,
phi::funcs::AddFunctor<T>());
phi::funcs::AddFunctor<T>(),
elewise_add_axis);
phi::SoftmaxForwardCUDAKernelDriver<T>(
dev_ctx_, *src_mask_out_tensor, softmax_axis, softmax_out_tensor);
......@@ -435,12 +434,11 @@ class FMHARef {
ins.emplace_back(src_mask_tensor);
outs.emplace_back(src_mask_out_tensor);
int elewise_add_axis = -1;
phi::funcs::BroadcastKernel<phi::ElementwiseType::kBinary, T, T>(
dev_ctx_,
phi::funcs::BroadcastKernel<T>(dev_ctx_,
ins,
&outs,
elewise_add_axis,
phi::funcs::AddFunctor<T>());
phi::funcs::AddFunctor<T>(),
elewise_add_axis);
phi::SoftmaxForwardCUDAKernelDriver<T>(
dev_ctx_, *src_mask_out_tensor, softmax_axis, softmax_out_tensor);
......
......@@ -106,8 +106,8 @@ struct DirichletSampler<GPUContext, T> {
{new_shape.size() - 1},
true,
false);
funcs::ElementwiseCompute<funcs::DivideFunctor<T>, T, T>(
dev_ctx, gamma_samples, gamma_sum, -1, funcs::DivideFunctor<T>(), out);
funcs::ElementwiseCompute<funcs::DivideFunctor<T>, T>(
dev_ctx, gamma_samples, gamma_sum, funcs::DivideFunctor<T>(), out);
}
};
} // namespace phi
......
......@@ -37,8 +37,7 @@ void DivideGradKernel(const Context& dev_ctx,
const auto place = dev_ctx.GetPlace();
if (dx != nullptr && dy != nullptr) {
std::vector<const DenseTensor*> ins = {&dout, &x, &y};
GetGradXAndYOut<ElementwiseType::kTernary, T>(
dev_ctx,
GetGradXAndYOut<T>(dev_ctx,
place,
axis,
ins,
......@@ -48,11 +47,11 @@ void DivideGradKernel(const Context& dev_ctx,
funcs::DivGradXYFunctor<T, T>());
} else if (dx != nullptr && dy == nullptr) {
std::vector<const DenseTensor*> ins = {&dout, &y};
GetGradXOrYOut<ElementwiseType::kBinary, T>(
GetGradXOrYOut<T>(
dev_ctx, place, axis, ins, dout, dx, funcs::DivGradXFunctor<T>());
} else if (dy != nullptr && dx == nullptr) {
std::vector<const DenseTensor*> ins = {&dout, &x, &y};
GetGradXOrYOut<ElementwiseType::kTernary, T>(
GetGradXOrYOut<T>(
dev_ctx, place, axis, ins, dout, dy, funcs::DivGradYFunctor<T>());
}
}
......
......@@ -35,7 +35,7 @@ void ReduceWrapper(const GPUContext &dev_ctx,
dev_ctx, *src, dst, kps::IdentityFunctor<T>(), reduce_dims);
}
template <ElementwiseType ET, typename T, typename Functor>
template <typename T, typename Functor>
void GetGradXAndYOut(const GPUContext &dev_ctx,
const Place &place,
int axis,
......@@ -67,8 +67,7 @@ void GetGradXAndYOut(const GPUContext &dev_ctx,
outs = {&tmp_dx, &tmp_dy};
}
funcs::BroadcastKernel<ET, T, T, decltype(func), 2>(
dev_ctx, ins, &outs, axis, func);
funcs::BroadcastKernel<T, decltype(func), 2>(dev_ctx, ins, &outs, func, axis);
if (dx->dims() != dout.dims() && dy->dims() == dout.dims()) {
ReduceWrapper<T>(dev_ctx, axis, &tmp_dx, dx);
......@@ -80,7 +79,7 @@ void GetGradXAndYOut(const GPUContext &dev_ctx,
}
}
template <ElementwiseType ET, typename T, typename Functor>
template <typename T, typename Functor>
void GetGradXOrYOut(const GPUContext &dev_ctx,
const Place &place,
int axis,
......@@ -100,7 +99,7 @@ void GetGradXOrYOut(const GPUContext &dev_ctx,
outs = {dxy};
}
funcs::BroadcastKernel<ET, T, T>(dev_ctx, ins, &outs, axis, func);
funcs::BroadcastKernel<T>(dev_ctx, ins, &outs, func, axis);
if (dxy->dims() != dout.dims()) {
ReduceWrapper<T>(dev_ctx, axis, &tmp_dxy, dxy);
}
......@@ -342,8 +341,7 @@ void ElementwiseDivGrad(const GPUContext &dev_ctx,
const auto place = dev_ctx.GetPlace();
if (dx != nullptr && dy != nullptr) {
std::vector<const DenseTensor *> ins = {&dout, &out, &y};
GetGradXAndYOut<ElementwiseType::kTernary, T>(
dev_ctx,
GetGradXAndYOut<T>(dev_ctx,
place,
axis,
ins,
......@@ -353,11 +351,11 @@ void ElementwiseDivGrad(const GPUContext &dev_ctx,
funcs::DivGradXYFunctor<T, T>());
} else if (dx != nullptr && dy == nullptr) {
std::vector<const DenseTensor *> ins = {&dout, &y};
GetGradXOrYOut<ElementwiseType::kBinary, T>(
GetGradXOrYOut<T>(
dev_ctx, place, axis, ins, dout, dx, funcs::DivGradXFunctor<T>());
} else if (dy != nullptr && dx == nullptr) {
std::vector<const DenseTensor *> ins = {&dout, &out, &y};
GetGradXOrYOut<ElementwiseType::kTernary, T>(
GetGradXOrYOut<T>(
dev_ctx, place, axis, ins, dout, dy, funcs::DivGradYFunctor<T>());
}
}
......@@ -380,8 +378,7 @@ void ElementwiseMulGrad(const GPUContext &dev_ctx,
if (dx != nullptr && dy != nullptr) {
std::vector<const DenseTensor *> ins = {&dout, &y, &x};
GetGradXAndYOut<ElementwiseType::kTernary, T>(
dev_ctx,
GetGradXAndYOut<T>(dev_ctx,
place,
axis,
ins,
......@@ -391,11 +388,11 @@ void ElementwiseMulGrad(const GPUContext &dev_ctx,
funcs::MultiplyGradXYFunctor<T, T>());
} else if (dx != nullptr && dy == nullptr) {
std::vector<const DenseTensor *> ins = {&dout, &y};
GetGradXOrYOut<ElementwiseType::kBinary, T>(
GetGradXOrYOut<T>(
dev_ctx, place, axis, ins, dout, dx, funcs::MultiplyGradFunctor<T>());
} else if (dx == nullptr && dy != nullptr) {
std::vector<const DenseTensor *> ins = {&dout, &x};
GetGradXOrYOut<ElementwiseType::kBinary, T>(
GetGradXOrYOut<T>(
dev_ctx, place, axis, ins, dout, dy, funcs::MultiplyGradFunctor<T>());
}
}
......
......@@ -38,8 +38,7 @@ void MaximumGradKernel(const Context& dev_ctx,
if (dx != nullptr && dy != nullptr) {
std::vector<const DenseTensor*> ins = {&x, &y, &dout};
GetGradXAndYOut<ElementwiseType::kTernary, T>(
dev_ctx,
GetGradXAndYOut<T>(dev_ctx,
place,
axis,
ins,
......@@ -49,11 +48,11 @@ void MaximumGradKernel(const Context& dev_ctx,
funcs::MaxGradXYFunctor<T, T>());
} else if (dx != nullptr && dy == nullptr) {
std::vector<const DenseTensor*> ins = {&x, &y, &dout};
GetGradXOrYOut<ElementwiseType::kBinary, T>(
GetGradXOrYOut<T>(
dev_ctx, place, axis, ins, dout, dx, funcs::MaxGradXFunctor<T>());
} else if (dy != nullptr && dx == nullptr) {
std::vector<const DenseTensor*> ins = {&x, &y, &dout};
GetGradXOrYOut<ElementwiseType::kTernary, T>(
GetGradXOrYOut<T>(
dev_ctx, place, axis, ins, dout, dy, funcs::MaxGradYFunctor<T>());
}
}
......@@ -69,8 +68,7 @@ void MinimumGradKernel(const Context& dev_ctx,
const auto place = dev_ctx.GetPlace();
if (dx != nullptr && dy != nullptr) {
std::vector<const DenseTensor*> ins = {&x, &y, &dout};
GetGradXAndYOut<ElementwiseType::kTernary, T>(
dev_ctx,
GetGradXAndYOut<T>(dev_ctx,
place,
axis,
ins,
......@@ -80,11 +78,11 @@ void MinimumGradKernel(const Context& dev_ctx,
funcs::MinGradXYFunctor<T, T>());
} else if (dx != nullptr && dy == nullptr) {
std::vector<const DenseTensor*> ins = {&x, &y, &dout};
GetGradXOrYOut<ElementwiseType::kBinary, T>(
GetGradXOrYOut<T>(
dev_ctx, place, axis, ins, dout, dx, funcs::MinGradXFunctor<T>());
} else if (dy != nullptr && dx == nullptr) {
std::vector<const DenseTensor*> ins = {&x, &y, &dout};
GetGradXOrYOut<ElementwiseType::kTernary, T>(
GetGradXOrYOut<T>(
dev_ctx, place, axis, ins, dout, dy, funcs::MinGradYFunctor<T>());
}
}
......
......@@ -74,8 +74,7 @@ void ExpandAsKernel(const Context& ctx,
ctx.template Alloc<T>(out);
std::vector<const DenseTensor*> ins = {&x};
std::vector<DenseTensor*> outs = {out};
phi::funcs::BroadcastKernel<ElementwiseType::kUnary, T, T>(
ctx, ins, &outs, -1, kps::IdentityFunctor<T>());
phi::funcs::BroadcastKernel<T>(ctx, ins, &outs, kps::IdentityFunctor<T>());
}
} // namespace phi
......
......@@ -73,8 +73,7 @@ void ExpandKernel(const Context& ctx,
ctx.template Alloc<T>(out);
std::vector<const DenseTensor*> ins = {&x};
std::vector<DenseTensor*> outs = {out};
phi::funcs::BroadcastKernel<ElementwiseType::kUnary, T, T>(
ctx, ins, &outs, -1, kps::IdentityFunctor<T>());
phi::funcs::BroadcastKernel<T>(ctx, ins, &outs, kps::IdentityFunctor<T>());
}
} // namespace phi
......
......@@ -407,11 +407,10 @@ void MatrixRankTolKernel(const Context& dev_ctx,
tol_tensor.Resize(dim_out);
dev_ctx.template Alloc<T>(&tol_tensor);
funcs::ElementwiseCompute<GreaterElementFunctor<T>, T, T>(
funcs::ElementwiseCompute<GreaterElementFunctor<T>, T>(
dev_ctx,
atol_tensor,
rtol_tensor,
-1,
GreaterElementFunctor<T>(),
&tol_tensor);
......@@ -421,12 +420,10 @@ void MatrixRankTolKernel(const Context& dev_ctx,
compare_result.Resize(detail::NewAxisDim(dim_out, k));
dev_ctx.template Alloc<int64_t>(&compare_result);
int axis = -1;
funcs::ElementwiseCompute<funcs::GreaterThanFunctor<T, int64_t>, T, int64_t>(
dev_ctx,
eigenvalue_tensor,
tol_tensor,
axis,
funcs::GreaterThanFunctor<T, int64_t>(),
&compare_result);
......
......@@ -78,8 +78,8 @@ void ReduceCudaAMaxAMinGrad(const Context& dev_ctx,
// 1. equal_out = Equal(x, y)
std::vector<const phi::DenseTensor*> equal_inputs = {&new_y, new_in_tensor};
std::vector<phi::DenseTensor*> equal_outputs = {&equal_out_tensor};
funcs::BroadcastKernel<phi::ElementwiseType::kBinary, T, T>(
dev_ctx, equal_inputs, &equal_outputs, 0, funcs::EqualFunctor<T>());
funcs::BroadcastKernel<T>(
dev_ctx, equal_inputs, &equal_outputs, funcs::EqualFunctor<T>(), 0);
// 2. equal_count = reduceSum(equal_out)
using MPType = typename kps::details::MPTypeTrait<T>::Type;
phi::funcs::
......@@ -95,15 +95,15 @@ void ReduceCudaAMaxAMinGrad(const Context& dev_ctx,
std::vector<const phi::DenseTensor*> mul_inputs = {&new_dout,
&equal_out_tensor};
std::vector<phi::DenseTensor*> mul_outputs = {&equal_out_tensor};
funcs::BroadcastKernel<phi::ElementwiseType::kBinary, T, T>(
dev_ctx, mul_inputs, &mul_outputs, 0, funcs::MultiplyFunctor<T>());
funcs::BroadcastKernel<T>(
dev_ctx, mul_inputs, &mul_outputs, funcs::MultiplyFunctor<T>(), 0);
// 4. dx = Div(dx, equal_out)
std::vector<const phi::DenseTensor*> grad_inputs = {&equal_out_tensor,
equal_count};
std::vector<phi::DenseTensor*> grad_outputs = {new_dx_tensor};
funcs::BroadcastKernel<phi::ElementwiseType::kBinary, T, T>(
dev_ctx, grad_inputs, &grad_outputs, 0, funcs::DivideFunctor<T>());
funcs::BroadcastKernel<T>(
dev_ctx, grad_inputs, &grad_outputs, funcs::DivideFunctor<T>(), 0);
delete equal_out;
delete equal_count;
}
......
......@@ -28,7 +28,7 @@
namespace phi {
template <typename InT, typename Functor>
template <typename Functor>
void ReduceGrad(const GPUContext& dev_ctx,
DenseTensor* d_out,
DenseTensor* d_x,
......@@ -36,14 +36,13 @@ void ReduceGrad(const GPUContext& dev_ctx,
Functor functor) {
std::vector<const DenseTensor*> inputs = {d_out};
std::vector<DenseTensor*> outputs = {d_x};
PD_VISIT_ALL_TYPES(
out_dtype, "BroadcastKernel", ([&] {
funcs::BroadcastKernel<phi::ElementwiseType::kUnary, InT, data_t>(
dev_ctx, inputs, &outputs, 0, functor);
PD_VISIT_ALL_TYPES(out_dtype, "BroadcastKernel", ([&] {
funcs::BroadcastKernel<data_t>(
dev_ctx, inputs, &outputs, functor, 0);
}));
}
template <typename T, typename OutT, typename Context, typename Functor>
template <typename OutT, typename Context, typename Functor>
void ReduceGradKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& out_grad,
......@@ -79,8 +78,7 @@ void ReduceGradKernel(const Context& dev_ctx,
auto pt_d_x = *d_x;
std::vector<const DenseTensor*> inputs = {&pt_d_out};
std::vector<DenseTensor*> outputs = {&pt_d_x};
funcs::BroadcastKernel<phi::ElementwiseType::kUnary, T, OutT>(
dev_ctx, inputs, &outputs, 0, functor);
funcs::BroadcastKernel<OutT>(dev_ctx, inputs, &outputs, functor, 0);
}
} // namespace phi
......
......@@ -62,14 +62,14 @@ void ReduceMaxGradKernel(const Context& dev_ctx,
// 1. equal_out = Equal(x, y)
std::vector<const phi::DenseTensor*> equal_inputs = {&new_out, &x};
std::vector<phi::DenseTensor*> equal_outputs = {equal_out};
funcs::BroadcastKernel<phi::ElementwiseType::kBinary, T, T>(
dev_ctx, equal_inputs, &equal_outputs, 0, funcs::EqualFunctor<T>());
funcs::BroadcastKernel<T>(
dev_ctx, equal_inputs, &equal_outputs, funcs::EqualFunctor<T>(), 0);
// 2. dx = dout * 1
std::vector<const phi::DenseTensor*> mul_inputs = {&new_out_grad, equal_out};
std::vector<phi::DenseTensor*> mul_outputs = {x_grad};
funcs::BroadcastKernel<phi::ElementwiseType::kBinary, T, T>(
dev_ctx, mul_inputs, &mul_outputs, 0, funcs::MultiplyFunctor<T>());
funcs::BroadcastKernel<T>(
dev_ctx, mul_inputs, &mul_outputs, funcs::MultiplyFunctor<T>(), 0);
delete equal_out;
}
} // namespace phi
......
......@@ -53,8 +53,8 @@ void ReduceMeanGradKernel(const Context& dev_ctx,
std::vector<DenseTensor*> outputs = {x_grad};
using MPType = typename kps::details::MPTypeTrait<T>::Type;
funcs::BroadcastKernel<phi::ElementwiseType::kUnary, T, T>(
dev_ctx, inputs, &outputs, 0, kps::DivideFunctor<T, MPType>(reduce_num));
funcs::BroadcastKernel<T>(
dev_ctx, inputs, &outputs, kps::DivideFunctor<T, MPType>(reduce_num), 0);
}
} // namespace phi
......
......@@ -62,14 +62,14 @@ void ReduceMinGradKernel(const Context& dev_ctx,
// 1. equal_out = Equal(x, y)
std::vector<const phi::DenseTensor*> equal_inputs = {&new_out, &x};
std::vector<phi::DenseTensor*> equal_outputs = {equal_out};
funcs::BroadcastKernel<phi::ElementwiseType::kBinary, T, T>(
dev_ctx, equal_inputs, &equal_outputs, 0, funcs::EqualFunctor<T>());
funcs::BroadcastKernel<T>(
dev_ctx, equal_inputs, &equal_outputs, funcs::EqualFunctor<T>(), 0);
// 2. dx = dout * 1
std::vector<const phi::DenseTensor*> mul_inputs = {&new_out_grad, equal_out};
std::vector<phi::DenseTensor*> mul_outputs = {x_grad};
funcs::BroadcastKernel<phi::ElementwiseType::kBinary, T, T>(
dev_ctx, mul_inputs, &mul_outputs, 0, funcs::MultiplyFunctor<T>());
funcs::BroadcastKernel<T>(
dev_ctx, mul_inputs, &mul_outputs, funcs::MultiplyFunctor<T>(), 0);
delete equal_out;
}
} // namespace phi
......
......@@ -48,7 +48,7 @@ void ReduceSumGradKernel(const Context& dev_ctx,
// call ReduceGrad
dev_ctx.Alloc(x_grad, x.dtype());
using MPType = typename kps::details::MPTypeTrait<T>::Type;
phi::ReduceGrad<T, kps::IdentityFunctor<T, MPType>>(
phi::ReduceGrad<kps::IdentityFunctor<T, MPType>>(
dev_ctx,
&new_out_grad,
x_grad,
......
......@@ -46,8 +46,7 @@ void SquaredL2NormGradKernel(const Context& dev_ctx,
std::vector<const DenseTensor*> ins{&x, &dout};
std::vector<DenseTensor*> outs{dx};
funcs::BroadcastKernel<ElementwiseType::kBinary, T, T>(
dev_ctx, ins, &outs, -1, phi::DoubleMulFunctor<T>());
funcs::BroadcastKernel<T>(dev_ctx, ins, &outs, phi::DoubleMulFunctor<T>());
}
} // namespace phi
......
......@@ -78,8 +78,8 @@ void TileKernel(const Context& dev_ctx,
tmp_out.Resize(make_ddim(vec_x_dims));
dev_ctx.template Alloc<T>(&tmp_out);
std::vector<DenseTensor*> outs = {&tmp_out};
phi::funcs::BroadcastKernel<ElementwiseType::kUnary, T, T>(
dev_ctx, ins, &outs, i, kps::IdentityFunctor<T>());
phi::funcs::BroadcastKernel<T>(
dev_ctx, ins, &outs, kps::IdentityFunctor<T>(), i);
tmp_out.Resize(out_dims);
new_x = tmp_out;
}
......@@ -89,8 +89,8 @@ void TileKernel(const Context& dev_ctx,
out->Resize(make_ddim(vec_x_dims));
dev_ctx.template Alloc<T>(out);
std::vector<DenseTensor*> outs = {out};
phi::funcs::BroadcastKernel<ElementwiseType::kUnary, T, T>(
dev_ctx, ins, &outs, i, kps::IdentityFunctor<T>());
phi::funcs::BroadcastKernel<T>(
dev_ctx, ins, &outs, kps::IdentityFunctor<T>(), i);
out->Resize(out_dims);
}
}
......
......@@ -91,8 +91,7 @@ struct BinaryOperation {
DenseTensor* output) {
std::vector<const DenseTensor*> ins{&lhs, &rhs};
std::vector<DenseTensor*> outs{output};
phi::funcs::BroadcastKernel<phi::ElementwiseType::kBinary, T, T>(
dev_ctx, ins, &outs, 0, BinaryFunctor<T>());
phi::funcs::BroadcastKernel<T>(dev_ctx, ins, &outs, BinaryFunctor<T>(), 0);
}
};
......
......@@ -90,16 +90,16 @@ void ComplexKernel(const Context& dev_ctx,
// facility functions
#if defined(__NVCC__) || defined(__HIPCC__)
phi::funcs::ElementwiseCompute<RealAndImagToComplexFunctor<T>, T, C>(
dev_ctx, x, y, /*axis*/ -1, RealAndImagToComplexFunctor<T>(), out);
dev_ctx, x, y, RealAndImagToComplexFunctor<T>(), out);
#else
auto x_dims = x.dims();
auto y_dims = y.dims();
if (x_dims.size() >= y_dims.size()) {
phi::funcs::ElementwiseCompute<RealAndImagToComplexFunctor<T>, T, C>(
dev_ctx, x, y, /*axis*/ -1, RealAndImagToComplexFunctor<T>(), out);
dev_ctx, x, y, RealAndImagToComplexFunctor<T>(), out);
} else {
phi::funcs::ElementwiseCompute<ImagAndRealToComplexFunctor<T>, T, C>(
dev_ctx, x, y, /*axis*/ -1, ImagAndRealToComplexFunctor<T>(), out);
dev_ctx, x, y, ImagAndRealToComplexFunctor<T>(), out);
}
#endif
}
......
......@@ -76,15 +76,15 @@ void AddDoubleGradImpl(const Context& dev_ctx,
auto ddy_dims = ddy_safe.dims();
if (ddx_dims.size() >= ddy_dims.size()) {
funcs::ElementwiseCompute<funcs::AddFunctor<T>, T>(
dev_ctx, ddx_safe, ddy_safe, axis, funcs::AddFunctor<T>(), ddout);
dev_ctx, ddx_safe, ddy_safe, funcs::AddFunctor<T>(), ddout, axis);
} else {
funcs::ElementwiseCompute<funcs::InverseAddFunctor<T>, T>(
dev_ctx,
ddx_safe,
ddy_safe,
axis,
funcs::InverseAddFunctor<T>(),
ddout);
ddout,
axis);
}
}
}
......@@ -107,7 +107,7 @@ void SubtractDoubleGradImpl(const Context& dev_ctx,
dev_ctx.template Alloc<T>(ddout);
funcs::ElementwiseCompute<funcs::SubtractFunctor<T>, T>(
dev_ctx, ddx_safe, ddy_safe, axis, funcs::SubtractFunctor<T>(), ddout);
dev_ctx, ddx_safe, ddy_safe, funcs::SubtractFunctor<T>(), ddout, axis);
}
}
......
......@@ -39,10 +39,10 @@ namespace phi {
auto y_dims = y.dims(); \
if (x_dims.size() >= y_dims.size()) { \
funcs::ElementwiseCompute<funcs::name##Functor<T>, T>( \
dev_ctx, x, y, axis, funcs::name##Functor<T>(), out); \
dev_ctx, x, y, funcs::name##Functor<T>(), out, axis); \
} else { \
funcs::ElementwiseCompute<funcs::Inverse##name##Functor<T>, T>( \
dev_ctx, x, y, axis, funcs::Inverse##name##Functor<T>(), out); \
dev_ctx, x, y, funcs::Inverse##name##Functor<T>(), out, axis); \
} \
} \
}
......@@ -62,8 +62,8 @@ namespace phi {
inputs.emplace_back(&y); \
outputs.emplace_back(out); \
dev_ctx.template Alloc<T>(out); \
funcs::BroadcastKernel<ElementwiseType::kBinary, T, T>( \
dev_ctx, inputs, &outputs, axis, funcs::name##Functor<T>()); \
funcs::BroadcastKernel<T>( \
dev_ctx, inputs, &outputs, funcs::name##Functor<T>(), axis); \
}
template <typename T, typename Context>
......@@ -72,8 +72,8 @@ void FMaxKernel(const Context& dev_ctx,
const DenseTensor& y,
DenseTensor* out) {
dev_ctx.template Alloc<T>(out);
funcs::ElementwiseCompute<funcs::FMaxFunctor<T>, T, T>(
dev_ctx, x, y, -1, funcs::FMaxFunctor<T>(), out);
funcs::ElementwiseCompute<funcs::FMaxFunctor<T>, T>(
dev_ctx, x, y, funcs::FMaxFunctor<T>(), out);
}
template <typename T, typename Context>
......@@ -82,8 +82,8 @@ void FMinKernel(const Context& dev_ctx,
const DenseTensor& y,
DenseTensor* out) {
dev_ctx.template Alloc<T>(out);
funcs::ElementwiseCompute<funcs::FMinFunctor<T>, T, T>(
dev_ctx, x, y, -1, funcs::FMinFunctor<T>(), out);
funcs::ElementwiseCompute<funcs::FMinFunctor<T>, T>(
dev_ctx, x, y, funcs::FMinFunctor<T>(), out);
}
} // namespace phi
......@@ -153,12 +153,8 @@ void SetValueCompute(const Context& dev_ctx,
slice_tensor.Resize(slice_dims_for_assign);
if (value_tensor != nullptr) {
CheckIsDimsMatch(slice_dims_for_assign, value_tensor->dims());
phi::funcs::ElementwiseCompute<SubFunctor<T>, T, T>(dev_ctx,
slice_tensor,
*value_tensor,
-1,
SubFunctor<T>(),
&slice_tensor);
phi::funcs::ElementwiseCompute<SubFunctor<T>, T>(
dev_ctx, slice_tensor, *value_tensor, SubFunctor<T>(), &slice_tensor);
} else {
DenseTensor value_t(dtype);
auto value_dims = phi::make_ddim(shape);
......@@ -166,8 +162,8 @@ void SetValueCompute(const Context& dev_ctx,
value_t.Resize(value_dims);
dev_ctx.template Alloc<T>(&value_t);
phi::funcs::ElementwiseCompute<SubFunctor<T>, T, T>(
dev_ctx, slice_tensor, value_t, -1, SubFunctor<T>(), &slice_tensor);
phi::funcs::ElementwiseCompute<SubFunctor<T>, T>(
dev_ctx, slice_tensor, value_t, SubFunctor<T>(), &slice_tensor);
}
slice_tensor.Resize(slice_dims);
......
......@@ -204,7 +204,6 @@ void SetValueImpl(const Context& dev_ctx,
dev_ctx,
slice_tensor,
value,
-1,
funcs::SubtractFunctor<T>(),
&slice_tensor);
} else {
......@@ -212,7 +211,6 @@ void SetValueImpl(const Context& dev_ctx,
dev_ctx,
slice_tensor,
value,
-1,
funcs::InverseSubtractFunctor<T>(),
&slice_tensor);
}
......
......@@ -35,8 +35,7 @@ namespace phi {
funcs::Bitwise##op_type##Functor<T> func; \
std::vector<const DenseTensor*> ins = {&x, &y}; \
std::vector<DenseTensor*> outs = {out}; \
funcs::BroadcastKernel<ElementwiseType::kBinary, T, T>( \
dev_ctx, ins, &outs, -1, func); \
funcs::BroadcastKernel<T>(dev_ctx, ins, &outs, func); \
}
DEFINE_BITWISE_KERNEL(And)
......
......@@ -55,8 +55,7 @@ inline void CompareKernelImpl(const Context& ctx,
ctx.template Alloc<bool>(out);
std::vector<const DenseTensor*> ins{&x, &y};
std::vector<DenseTensor*> outs{out};
funcs::BroadcastKernel<ElementwiseType::kBinary, T, bool>(
ctx, ins, &outs, axis, Functor());
funcs::BroadcastKernel<bool>(ctx, ins, &outs, Functor(), axis);
}
#ifndef PADDLE_WITH_XPU_KP
......
......@@ -69,8 +69,8 @@ void HeavisideKernel(const Context& dev_ctx,
inputs.emplace_back(&y);
outputs.emplace_back(out);
dev_ctx.template Alloc<T>(out);
funcs::BroadcastKernel<ElementwiseType::kBinary, T, T>(
dev_ctx, inputs, &outputs, -1, funcs::ElementwiseHeavisideFunctor<T>());
funcs::BroadcastKernel<T>(
dev_ctx, inputs, &outputs, funcs::ElementwiseHeavisideFunctor<T>());
}
// Create the definition of Pow
......
......@@ -31,14 +31,11 @@ namespace phi {
const DenseTensor& x, \
const DenseTensor& y, \
DenseTensor* out) { \
using InT = typename funcs::Logical##type##Functor<T>::ELEMENT_TYPE; \
using OutT = bool; \
dev_ctx.template Alloc<bool>(out); \
funcs::Logical##type##Functor<T> binary_func; \
std::vector<const DenseTensor*> ins = {&x, &y}; \
std::vector<DenseTensor*> outs = {out}; \
funcs::BroadcastKernel<ElementwiseType::kBinary, InT, OutT>( \
dev_ctx, ins, &outs, -1, binary_func); \
funcs::BroadcastKernel<bool>(dev_ctx, ins, &outs, binary_func); \
}
DEFINE_LOGICAL_BINARY_KERNEL(And)
......@@ -50,15 +47,11 @@ template <typename T, typename Context>
void LogicalNotKernel(const Context& dev_ctx,
const DenseTensor& x,
DenseTensor* out) {
using InT = typename funcs::LogicalNotFunctor<T>::ELEMENT_TYPE;
using OutT = bool;
dev_ctx.template Alloc<bool>(out);
funcs::LogicalNotFunctor<T> unary_func;
std::vector<const DenseTensor*> ins = {&x};
std::vector<DenseTensor*> outs = {out};
funcs::BroadcastKernel<ElementwiseType::kUnary, InT, OutT>(
dev_ctx, ins, &outs, -1, unary_func);
funcs::BroadcastKernel<bool>(dev_ctx, ins, &outs, unary_func);
}
} // namespace phi
......
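The reason only the output type is still spelled out in these calls is that the input element types can be recovered from the functor's operator(). Below is a stripped-down standalone illustration of that deduction (not the phi::funcs::FunctionTraits implementation); LogicalNot here is a local stand-in.

```cpp
#include <cstdio>
#include <type_traits>

// Extract the argument type of a unary const member operator().
template <typename F>
struct FunctorArg;

template <typename R, typename C, typename A>
struct FunctorArg<R (C::*)(A) const> {
  using type = A;
};

template <typename T>
struct LogicalNot {
  bool operator()(T x) const { return !static_cast<bool>(x); }
};

int main() {
  using Arg = typename FunctorArg<decltype(&LogicalNot<int>::operator())>::type;
  static_assert(std::is_same<Arg, int>::value, "input type deduced from functor");
  std::printf("%d\n", LogicalNot<int>{}(0));  // prints 1
  return 0;
}
```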
......@@ -30,7 +30,7 @@ void MaximumRawKernel(const Context& dev_ctx,
// allocate memory for out
dev_ctx.template Alloc<T>(out);
funcs::ElementwiseCompute<funcs::MaximumFunctor<T>, T>(
dev_ctx, x, y, axis, funcs::MaximumFunctor<T>(), out);
dev_ctx, x, y, funcs::MaximumFunctor<T>(), out, axis);
}
template <typename T, typename Context>
......@@ -42,7 +42,7 @@ void MinimumRawKernel(const Context& dev_ctx,
// allocate memory for out
dev_ctx.template Alloc<T>(out);
funcs::ElementwiseCompute<funcs::MinimumFunctor<T>, T>(
dev_ctx, x, y, axis, funcs::MinimumFunctor<T>(), out);
dev_ctx, x, y, funcs::MinimumFunctor<T>(), out, axis);
}
template <typename T, typename Context>
......@@ -57,10 +57,10 @@ void RemainderRawKernel(const Context& dev_ctx,
auto y_dims = y.dims();
if (x_dims.size() >= y_dims.size()) {
funcs::ElementwiseCompute<funcs::RemainderFunctor<T>, T>(
dev_ctx, x, y, axis, funcs::RemainderFunctor<T>(), out);
dev_ctx, x, y, funcs::RemainderFunctor<T>(), out, axis);
} else {
funcs::ElementwiseCompute<funcs::InverseRemainderFunctor<T>, T>(
dev_ctx, x, y, axis, funcs::InverseRemainderFunctor<T>(), out);
dev_ctx, x, y, funcs::InverseRemainderFunctor<T>(), out, axis);
}
}
......@@ -76,10 +76,10 @@ void FloorDivideRawKernel(const Context& dev_ctx,
auto y_dims = y.dims();
if (x_dims.size() >= y_dims.size()) {
funcs::ElementwiseCompute<funcs::FloorDivideFunctor<T>, T>(
dev_ctx, x, y, axis, funcs::FloorDivideFunctor<T>(), out);
dev_ctx, x, y, funcs::FloorDivideFunctor<T>(), out, axis);
} else {
funcs::ElementwiseCompute<funcs::InverseFloorDivideFunctor<T>, T>(
dev_ctx, x, y, axis, funcs::InverseFloorDivideFunctor<T>(), out);
dev_ctx, x, y, funcs::InverseFloorDivideFunctor<T>(), out, axis);
}
}
......@@ -95,10 +95,10 @@ void ElementwisePowRawKernel(const Context& dev_ctx,
auto y_dims = y.dims();
if (x_dims.size() >= y_dims.size()) {
funcs::ElementwiseCompute<funcs::ElementwisePowFunctor<T>, T>(
dev_ctx, x, y, axis, funcs::ElementwisePowFunctor<T>(), out);
dev_ctx, x, y, funcs::ElementwisePowFunctor<T>(), out, axis);
} else {
funcs::ElementwiseCompute<funcs::ElementwiseInversePowFunctor<T>, T>(
dev_ctx, x, y, axis, funcs::ElementwiseInversePowFunctor<T>(), out);
dev_ctx, x, y, funcs::ElementwiseInversePowFunctor<T>(), out, axis);
}
}
......
......@@ -36,8 +36,8 @@ void MaximumRawKernel(const Context& dev_ctx,
inputs.emplace_back(&y);
outputs.emplace_back(out);
dev_ctx.template Alloc<T>(out);
funcs::BroadcastKernel<ElementwiseType::kBinary, T, T>(
dev_ctx, inputs, &outputs, axis, funcs::MaximumFunctor<T>());
funcs::BroadcastKernel<T>(
dev_ctx, inputs, &outputs, funcs::MaximumFunctor<T>(), axis);
}
template <typename T, typename Context>
......@@ -54,8 +54,8 @@ void MinimumRawKernel(const Context& dev_ctx,
inputs.emplace_back(&y);
outputs.emplace_back(out);
dev_ctx.template Alloc<T>(out);
funcs::BroadcastKernel<ElementwiseType::kBinary, T, T>(
dev_ctx, inputs, &outputs, axis, funcs::MinimumFunctor<T>());
funcs::BroadcastKernel<T>(
dev_ctx, inputs, &outputs, funcs::MinimumFunctor<T>(), axis);
}
template <typename T, typename Context>
......@@ -72,8 +72,8 @@ void RemainderRawKernel(const Context& dev_ctx,
inputs.emplace_back(&y);
outputs.emplace_back(out);
dev_ctx.template Alloc<T>(out);
funcs::BroadcastKernel<ElementwiseType::kBinary, T, T>(
dev_ctx, inputs, &outputs, axis, funcs::RemainderFunctor<T>());
funcs::BroadcastKernel<T>(
dev_ctx, inputs, &outputs, funcs::RemainderFunctor<T>(), axis);
}
template <typename T, typename Context>
......@@ -90,8 +90,8 @@ void FloorDivideRawKernel(const Context& dev_ctx,
inputs.emplace_back(&y);
outputs.emplace_back(out);
dev_ctx.template Alloc<T>(out);
funcs::BroadcastKernel<ElementwiseType::kBinary, T, T>(
dev_ctx, inputs, &outputs, axis, funcs::FloorDivideFunctor<T>());
funcs::BroadcastKernel<T>(
dev_ctx, inputs, &outputs, funcs::FloorDivideFunctor<T>(), axis);
}
template <typename T, typename Context>
......@@ -108,8 +108,8 @@ void ElementwisePowRawKernel(const Context& dev_ctx,
inputs.emplace_back(&y);
outputs.emplace_back(out);
dev_ctx.template Alloc<T>(out);
funcs::BroadcastKernel<ElementwiseType::kBinary, T, T>(
dev_ctx, inputs, &outputs, axis, funcs::ElementwisePowFunctor<T>());
funcs::BroadcastKernel<T>(
dev_ctx, inputs, &outputs, funcs::ElementwisePowFunctor<T>(), axis);
}
} // namespace phi
......
......@@ -255,6 +255,18 @@ __device__ __forceinline__ void Init(ArgsT* dst, T init_data, int read_lens) {
}
}
/**
* The difference from the above function is that
* it supports different data types of inputs.
*/
template <typename T, typename ArgsT, int Index, int NX>
__device__ __forceinline__ void Init(ArgsT* dst, T init_data) {
#pragma unroll
for (int i = 0; i < NX; i++) {
std::get<Index>(dst[i]) = init_data;
}
}
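A host-side analogue of this overload, shown with std::tuple to make the per-slot initialization concrete; the slot layout (float plus short) is only an example.

```cpp
#include <cstdio>
#include <tuple>

// dst is an array of per-thread argument tuples; only the Index-th slot
// (one input of the elementwise op) is initialized, so each input can use
// its own element type.
template <int Index, int NX, typename T, typename ArgsT>
void InitSlot(ArgsT* dst, T init_data) {
  for (int i = 0; i < NX; ++i) {
    std::get<Index>(dst[i]) = init_data;
  }
}

int main() {
  std::tuple<float, short> args[4];
  InitSlot<0, 4>(args, 0.0f);
  InitSlot<1, 4>(args, static_cast<short>(0));
  std::printf("%f %d\n", std::get<0>(args[2]), std::get<1>(args[2]));
  return 0;
}
```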
/**
* @brief Read 1D data from global memory to register. When IsBoundary = true
* and (NX % 4 == 0 or Nx % 2 == 0), vectorized load data will be used to
......@@ -307,6 +319,23 @@ __device__ __forceinline__ void ReadData(T* dst,
}
}
/**
* @brief Read 1D data from global memory to register.
 * @template parameters
* T: The type of data.
 * NX: Each thread loads NX elements from global memory contiguously.
 * NY: Each thread needs to load NY rows; only NY = 1 is supported.
* IsBoundary: Whether to make an out-of-bounds judgment on access to memory.
* When the number of data processed by this block is less than
* NX x NY x blockDim.x, boundary judgment is required to avoid memory access
* crossing the boundary.
*
* @param:
* dst: The register pointer of the thread, the size is NX * NY.
* src: The data pointer of the current block.
* size: The current block needs to load size data continuously.
*/
template <typename T, int NX, int NY, bool IsBoundary = false>
__device__ __forceinline__ void ReadData(T* dst,
const T* __restrict__ src,
......@@ -347,9 +376,8 @@ __device__ __forceinline__ void ReadData(T* dst,
* T: The type of data.
* NX: Each thread load NX data from global memory continuously.
* NY: Each thread need to load NY rows, only NY = 1 was supported.
* ArgsT: The Type if dst, ArgsT can be std::tuple<T> or std::tuple<Args>
* ArgsT: The Type of dst, ArgsT can be std::tuple<T> or std::tuple<Args>
* Index: The index of data stored in dst.
* threadIdx.x is used as the thread index. Currently only GPU was supported.
* IsBoundary: Whether to make an out-of-bounds judgment on access to memory.
* When the number of data processed by this block is less than
* NX x NY x blockDim.x, boundary judgment is required to avoid memory access
......@@ -369,7 +397,7 @@ template <typename T,
__device__ __forceinline__ void ReadData(ArgsT* dst,
const T* __restrict__ src,
int num,
int read_lens) {
int read_lens = 0) {
if (IsBoundary) { // blockDim.x * NX > num
int thread_offset = threadIdx.x * NX;
#pragma unroll
......@@ -743,7 +771,6 @@ __device__ __forceinline__ void Init(T* dst, T* init_data, int num) {
* NX: The number of data continuously loaded by each thread.
* NY: The number of data rows loaded by each thread, only NY = 1 was supported.
* threadIdx.x is used as the thread index. Currently only GPU was supported.
* Rank: The shape size of out. eg in[1, 35], out[32, 35] then shape size is 2.
* IsBoundary: Indicates whether to perform block access storage out-of-bounds
* judgment. When the number of data processed by the block is less than
* NX x NY x blockDim.x, boundary judgment is required to avoid memory access
......@@ -788,6 +815,67 @@ __device__ __forceinline__ void ReadDataBc(
}
}
/**
* @brief Read 1D data from global memory to register with broadcast form.
* The difference from the above function is that it supports different data
* types of inputs.
*
 * @template parameters
* T: The type of data stored in the global memory.
* NX: The number of data continuously loaded by each thread.
 * NY: The number of data rows loaded by each thread; only NY = 1 is supported.
* ArgsT: The Type of dst, ArgsT can be std::tuple<T> or std::tuple<Args>
* Index: The index of data stored in dst.
* IsBoundary: Indicates whether to perform block access storage out-of-bounds
* judgment. When the number of data processed by the block is less than
* NX x NY x blockDim.x, boundary judgment is required to avoid memory access
* crossing the boundary.
*
* @param:
* dst: The register pointer of the thread, the size is NX * NY.
* src: The original input data pointer of kernel.
* block_offset: The data offset of this block, blockDim.x * blockIdx.x * NX;
* config: Calculation configuration of broadcast. It is used to calculate the
* coordinate mapping relationship between output data and input data.
* total_num_output: Total number of original output.
*/
template <typename T,
int NX,
int NY,
typename ArgsT,
int Index,
bool IsBoundary = false>
__device__ __forceinline__ void ReadDataBc(
ArgsT* dst,
const T* __restrict__ src,
uint32_t block_offset,
const details::BroadcastConfig& config,
int total_num_output,
int read_lens = NX) {
uint32_t thread_offset = block_offset + threadIdx.x * NX;
uint32_t index_src = 0;
#pragma unroll
for (uint32_t nx = 0; nx < NX; ++nx) {
uint32_t index_output = thread_offset + nx;
index_src = 0;
if (IsBoundary) {
if (index_output >= total_num_output) {
break;
}
}
#pragma unroll
for (int i = 0; i < phi::DDim::kMaxRank; ++i) {
if (i >= config.rank) break;
auto fast_divmoder = config.divmoders[i].Divmod(index_output);
index_output = fast_divmoder.val[0];
index_src += fast_divmoder.val[1] * config.strides[i];
}
std::get<Index>(dst[nx]) = src[index_src];
}
}
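The coordinate mapping in the loop above can be checked with an ordinary host function. This sketch replaces FastDivMod with plain division and modulo and uses hypothetical names; a stride of 0 is how a broadcast dimension contributes nothing to the source index.

```cpp
#include <cstdio>

// Decompose the output's linear index dimension by dimension (innermost
// first) and accumulate the input strides for each digit.
int BroadcastSourceIndex(int index_output,
                         const int* divisors,  // output dim sizes, innermost first
                         const int* strides,   // input strides, 0 if broadcast
                         int rank) {
  int index_src = 0;
  for (int i = 0; i < rank; ++i) {
    int remainder = index_output % divisors[i];
    index_output /= divisors[i];
    index_src += remainder * strides[i];
  }
  return index_src;
}

int main() {
  // Input shape [1, 3] broadcast against output shape [2, 3]:
  // the innermost dim has stride 1, the broadcast dim has stride 0.
  const int divisors[] = {3, 2};
  const int strides[] = {1, 0};
  for (int out = 0; out < 6; ++out) {
    std::printf("out %d -> src %d\n",
                out,
                BroadcastSourceIndex(out, divisors, strides, 2));
  }
  return 0;  // maps 0..5 to 0,1,2,0,1,2
}
```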
/**
* @brief Initialize register with data index.
*
......
......@@ -1211,6 +1211,65 @@ __device__ __inline__ void ReadDataBc(T* dst,
}
}
/**
* @brief Read 1D data from global memory to register with broadcast form.
* The difference from the above function is that it supports different data
* types of inputs.
 * @template parameters
* T: The type of data stored in the global memory.
* NX: The number of data continuously loaded by each thread.
 * NY: The number of data rows loaded by each thread; only NY = 1 is supported.
* core_id() is used as the index.
* IsBoundary: Indicates whether to perform block access storage out-of-bounds
* judgment. When the number of data processed by the block is less than
* NX x NY x core_num(), boundary judgment is required to avoid memory access
* crossing the boundary.
*
* @param:
* dst: The register pointer of the thread, the size is NX * NY.
* src: The original input data pointer of kernel.
* block_offset: The data offset of this block, core_num() * blockIdx.x * NX;
* config: Calculation configuration of broadcast. It is used to calculate the
* coordinate mapping relationship between output data and input data.
* read_lens: The number of data continuously loaded by each thread.
* total_num_output: Total number of original output.
*/
template <typename T,
int NX,
int NY,
typename ArgsT,
int Index,
bool IsBoundary = false>
__device__ __forceinline__ void ReadDataBc(
ArgsT* dst,
const T _global_ptr_* src,
int block_offset,
const details::BroadcastConfig& config,
int total_num_output,
int read_lens = NX) {
int thread_offset = block_offset + core_id() * read_lens;
__local__ T in_temp[NX];
if (config.cmp_type == details::OptType::MNK_M1K) {
ReadDataBcM1kMnk<T>(in_temp, src, thread_offset, config, read_lens);
} else if (config.cmp_type == details::OptType::N_1) {
ReadDataBc1N<T>(in_temp, src, thread_offset, config, read_lens);
} else if (config.cmp_type == details::OptType::MN_M) {
ReadDataBcM1Mn<T>(in_temp, src, thread_offset, config, read_lens);
} else if (config.cmp_type == details::OptType::MN_N) {
ReadDataBc1NMn<T>(in_temp, src, thread_offset, config, read_lens);
} else if (config.cmp_type == details::OptType::MNK_1N1) {
ReadDataBc1N1Mnk<T>(in_temp, src, thread_offset, config, read_lens);
} else {
ReadDataBcCanNotCmp<T, IsBoundary>(
in_temp, src, thread_offset, config, total_num_output, read_lens);
}
#pragma unroll
for (int idx = 0; idx < read_lens; ++idx) {
std::get<Index>(dst[idx]) = in_temp[idx];
}
}
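Unlike the CUDA overload, the XPU path keeps its specialized fast paths (ReadDataBcM1kMnk, ReadDataBc1N, and so on) working on a homogeneous local buffer and only scatters into the per-input tuple slot at the end. A host-side sketch of that final scatter step, with illustrative names and types:

```cpp
#include <tuple>

// Copy each element of a staged, single-typed buffer into the Index-th
// slot of the per-thread argument tuples.
template <int Index, int NX, typename T, typename ArgsT>
void ScatterToSlot(ArgsT* dst, const T* staged) {
  for (int i = 0; i < NX; ++i) {
    std::get<Index>(dst[i]) = staged[i];
  }
}

int main() {
  float staged[4] = {1.f, 2.f, 3.f, 4.f};
  std::tuple<float, int> args[4];
  ScatterToSlot<0, 4>(args, staged);
  return 0;
}
```

Staging into `in_temp` first lets the existing single-type broadcast readers be reused unchanged; only the last copy needs to know about the mixed-type ArgsT layout.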
/**
* @brief Initialize register with data index.
*
......
......@@ -89,8 +89,7 @@ void TestCase(const phi::GPUContext& dev_ctx,
d_in1.get(), d_in2.get(), d_in3.get()};
std::vector<phi::DenseTensor*> outputs{d_out.get()};
for (int i = 0; i < times; ++i) {
phi::funcs::BroadcastKernel<phi::ElementwiseType::kTernary, T, T>(
dev_ctx, inputs, &outputs, -1, compute);
phi::funcs::BroadcastKernel<T>(dev_ctx, inputs, &outputs, compute);
}
dev_ctx.Wait();
}
......