Unverified commit d611e48c authored by Bo Zhang, committed by GitHub

Optimize Dropout & clean up broadcast InT and ElementwiseType (#52969)

* change the condition check in DropoutGradGPUKernelDriver

* add UnrollerWithoutVecSize; LoadData to be refined after this

* pass unit tests

* use the same unroller as XPU

* BroadcastWithInt64Index

* BroadcastDataLoader template partial specialization

* fix compile errors on ROCm

* clean up ElementwiseType and InT for BroadcastKernel

* default axis and clean up InT

* remove redundant fast divmod computation

* optimize drop_nd & drop_nd_grad

* optimize BroadcastDataLoader for bf16/fp16

* remove InT etc. after merging develop

* delete constexpr for Windows CI

* fix conflict

* fix conflict with develop

* fix conflict

* further cleanup

* cleanup
Parent: a53ee944
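At the interface level, the change removes the explicit ElementwiseType / InT template arguments and moves axis behind the functor with a default of -1. A minimal before/after sketch of a typical call site, taken from the pattern that repeats throughout this diff (illustrative only, not a complete translation unit):

    // Before: the caller spelled out the input arity and input type.
    phi::funcs::BroadcastKernel<phi::ElementwiseType::kBinary, T, T>(
        dev_ctx, ins, &outs, -1, phi::funcs::AddFunctor<T>());

    // After: arity and input type are deduced from the functor; axis is an
    // optional trailing argument that defaults to -1.
    phi::funcs::BroadcastKernel<T>(dev_ctx, ins, &outs,
                                   phi::funcs::AddFunctor<T>());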
@@ -19,17 +19,13 @@
 namespace paddle {
 namespace operators {
-template <ElementwiseType ET,
-          typename InT,
-          typename OutT,
-          typename Functor,
-          int NumOuts = 1>
+template <typename OutT, typename Functor, int NumOuts = 1>
 void LaunchElementwiseCudaKernel(
     const KPDevice &ctx,
     const std::vector<const phi::DenseTensor *> &ins,
     std::vector<phi::DenseTensor *> *outs,
-    int axis,
-    Functor func) {
+    Functor func,
+    int axis = -1) {
   std::vector<const phi::DenseTensor *> pt_inputs;
   std::vector<phi::DenseTensor *> pt_outputs;
   // TODO(YuanRisheng) *_tmp for cache DenseTensor, because the temporary
@@ -53,8 +49,8 @@ void LaunchElementwiseCudaKernel(
   for (int i = 0; i < pt_outputs_tmp.size(); i++) {
     pt_outputs.push_back(pt_outputs_tmp[i].get());
   }
-  phi::funcs::BroadcastKernel<ET, InT, OutT, Functor, NumOuts>(
-      ctx, pt_inputs, &pt_outputs, axis, func);
+  phi::funcs::BroadcastKernel<OutT, Functor, NumOuts>(
+      ctx, pt_inputs, &pt_outputs, func, axis);
 }
 }  // namespace operators
...
@@ -188,7 +188,7 @@ void ElementwiseComputeEx(const framework::ExecutionContext &ctx,
   z->mutable_data<OutType>(ctx.GetPlace());
   const auto &dev_ctx = ctx.template device_context<DeviceContext>();
   phi::funcs::ElementwiseCompute<Functor, T, OutType>(
-      dev_ctx, *x, *y, axis, func, z);
+      dev_ctx, *x, *y, func, z, axis);
 }
 // FusedElemwiseAndAct
@@ -1596,7 +1596,7 @@ static inline std::vector<int> GetReduceDim(const framework::DDim &in,
 #if defined(__NVCC__) || defined(__HIPCC__)
-template <ElementwiseType ET, typename T, typename Functor>
+template <typename T, typename Functor>
 void GetGradXAndYOut(const phi::GPUContext &dev_ctx,
                      const platform::Place &place,
                      int axis,
@@ -1605,11 +1605,11 @@ void GetGradXAndYOut(const phi::GPUContext &dev_ctx,
                      phi::DenseTensor *dx,
                      phi::DenseTensor *dy,
                      Functor func) {
-  phi::GetGradXAndYOut<ET, T, Functor>(
+  phi::GetGradXAndYOut<T, Functor>(
      dev_ctx, place, axis, ins, *dout, dx, dy, func);
 }
-template <ElementwiseType ET, typename T, typename Functor>
+template <typename T, typename Functor>
 void GetGradXOrYOut(const phi::GPUContext &dev_ctx,
                     const platform::Place &place,
                     int axis,
@@ -1617,8 +1617,7 @@ void GetGradXOrYOut(const phi::GPUContext &dev_ctx,
                     const phi::DenseTensor *dout,
                     phi::DenseTensor *dxy,
                     Functor func) {
-  phi::GetGradXOrYOut<ET, T, Functor>(
-      dev_ctx, place, axis, ins, *dout, dxy, func);
+  phi::GetGradXOrYOut<T, Functor>(dev_ctx, place, axis, ins, *dout, dxy, func);
 }
 #endif
...
@@ -23,8 +23,6 @@ limitations under the License. */
 namespace paddle {
 namespace operators {
-using ElementwiseType = phi::ElementwiseType;
 template <typename OutT, typename Functor, int NumOuts = 1>
 void LaunchSameDimsElementwiseCudaKernel(
     const KPDevice &ctx,
...
@@ -109,8 +109,8 @@ class AttnMatMul {
     // bias_out = output + bias
     std::vector<const phi::DenseTensor*> ins = {output, bias};
     std::vector<phi::DenseTensor*> outs = {bias_out};
-    phi::funcs::BroadcastKernel<phi::ElementwiseType::kBinary, T, T>(
-        dev_ctx_, ins, &outs, -1, phi::funcs::AddFunctor<T>());
+    phi::funcs::BroadcastKernel<T>(
+        dev_ctx_, ins, &outs, phi::funcs::AddFunctor<T>());
   }
 }
...
@@ -85,8 +85,8 @@ class AttnMatmulINT8 {
     // bias_out = output + bias
     std::vector<const phi::DenseTensor*> ins = {output, bias};
     std::vector<phi::DenseTensor*> outs = {bias_out};
-    phi::funcs::BroadcastKernel<phi::ElementwiseType::kBinary, T, T>(
-        dev_ctx_, ins, &outs, -1, phi::funcs::AddFunctor<T>());
+    phi::funcs::BroadcastKernel<T>(
+        dev_ctx_, ins, &outs, phi::funcs::AddFunctor<T>());
     PADDLE_ENFORCE_EQ(cudaGetLastError(),
                       cudaSuccess,
                       platform::errors::Fatal(
@@ -139,8 +139,8 @@ class AttnMatmulINT8 {
     // bias_out = output + bias
     std::vector<const phi::DenseTensor*> ins = {output, bias};
     std::vector<phi::DenseTensor*> outs = {bias_out};
-    phi::funcs::BroadcastKernel<phi::ElementwiseType::kBinary, T, T>(
-        dev_ctx_, ins, &outs, -1, phi::funcs::AddFunctor<T>());
+    phi::funcs::BroadcastKernel<T>(
+        dev_ctx_, ins, &outs, phi::funcs::AddFunctor<T>());
     PADDLE_ENFORCE_EQ(cudaGetLastError(),
                       cudaSuccess,
                       platform::errors::Fatal(
...
@@ -255,12 +255,11 @@ class FMHARef {
     ins.emplace_back(src_mask_tensor);
     outs.emplace_back(src_mask_out_tensor);
     int elewise_add_axis = -1;
-    phi::funcs::BroadcastKernel<phi::ElementwiseType::kBinary, T, T>(
-        dev_ctx_,
-        ins,
-        &outs,
-        elewise_add_axis,
-        phi::funcs::AddFunctor<T>());
+    phi::funcs::BroadcastKernel<T>(dev_ctx_,
+                                   ins,
+                                   &outs,
+                                   phi::funcs::AddFunctor<T>(),
+                                   elewise_add_axis);
     phi::SoftmaxForwardCUDAKernelDriver<T>(
         dev_ctx_, *src_mask_out_tensor, softmax_axis, softmax_out_tensor);
@@ -432,12 +431,11 @@ class FMHARef {
     ins.emplace_back(src_mask_tensor);
     outs.emplace_back(src_mask_out_tensor);
     int elewise_add_axis = -1;
-    phi::funcs::BroadcastKernel<phi::ElementwiseType::kBinary, T, T>(
-        dev_ctx_,
-        ins,
-        &outs,
-        elewise_add_axis,
-        phi::funcs::AddFunctor<T>());
+    phi::funcs::BroadcastKernel<T>(dev_ctx_,
+                                   ins,
+                                   &outs,
+                                   phi::funcs::AddFunctor<T>(),
+                                   elewise_add_axis);
     phi::SoftmaxForwardCUDAKernelDriver<T>(
         dev_ctx_, *src_mask_out_tensor, softmax_axis, softmax_out_tensor);
...
@@ -689,13 +689,13 @@ class FMHAGateRef {
       std::vector<const phi::DenseTensor*> ins = {
           qk_out, src_mask, nonbatched_bias};
       std::vector<phi::DenseTensor*> outs = {qk_out};
-      phi::funcs::BroadcastKernel<phi::ElementwiseType::kTernary, T, T>(
-          dev_ctx_, ins, &outs, -1, TernaryAddFunctor<T>());
+      phi::funcs::BroadcastKernel<T>(
+          dev_ctx_, ins, &outs, TernaryAddFunctor<T>());
     } else {
       std::vector<const phi::DenseTensor*> ins = {qk_out, src_mask};
       std::vector<phi::DenseTensor*> outs = {qk_out};
-      phi::funcs::BroadcastKernel<phi::ElementwiseType::kBinary, T, T>(
-          dev_ctx_, ins, &outs, -1, phi::funcs::AddFunctor<T>());
+      phi::funcs::BroadcastKernel<T>(
+          dev_ctx_, ins, &outs, phi::funcs::AddFunctor<T>());
     }
     phi::SoftmaxForwardCUDAKernelDriver<T>(dev_ctx_, *qk_out, -1, softmax_out);
   }
...
@@ -141,8 +141,7 @@ class FusedTokenPruneOpCUDAKernel : public framework::OpKernel<T> {
     ins.emplace_back(attn);
     ins.emplace_back(mask);
     outs.emplace_back(&attn_tmp);
-    LaunchElementwiseCudaKernel<ElementwiseType::kBinary, T, T>(
-        dev_ctx, ins, &outs, -1, AttnMaskFunctor<T>());
+    LaunchElementwiseCudaKernel<T>(dev_ctx, ins, &outs, AttnMaskFunctor<T>());
     // 2. Reduce sum
     const std::vector<int64_t> reduce_dims{1, 2};
...
@@ -834,12 +834,11 @@ class ReduceCudaGradKernel : public framework::OpKernel<T> {
     }
     using MPType = typename kps::details::MPTypeTrait<T>::Type;
-    phi::ReduceGrad<T, TransformOp<T, MPType>>(
-        dev_ctx,
-        pt_d_out.get(),
-        pt_d_x.get(),
-        pt_out_dtype,
-        TransformOp<T, MPType>(reduce_num));
+    phi::ReduceGrad<TransformOp<T, MPType>>(dev_ctx,
+                                            pt_d_out.get(),
+                                            pt_d_x.get(),
+                                            pt_out_dtype,
+                                            TransformOp<T, MPType>(reduce_num));
   }
 };
...
@@ -24,15 +24,15 @@ limitations under the License. */
 namespace phi {
 #define DEFINE_BITWISE_KERNEL(op_type)                                   \
   template <typename T, typename Context>                                \
   void Bitwise##op_type##Kernel(const Context& dev_ctx,                  \
                                 const DenseTensor& x,                    \
                                 const DenseTensor& y,                    \
                                 DenseTensor* out) {                      \
     funcs::Bitwise##op_type##Functor<T> func;                            \
-    funcs::ElementwiseCompute<funcs::Bitwise##op_type##Functor<T>, T, T>( \
-        dev_ctx, x, y, -1, func, out);                                   \
+    funcs::ElementwiseCompute<funcs::Bitwise##op_type##Functor<T>, T>(   \
+        dev_ctx, x, y, func, out);                                       \
   }
 DEFINE_BITWISE_KERNEL(And)
...
@@ -33,10 +33,10 @@ inline void CompareKernelImpl(const Context& ctx,
   ctx.template Alloc<bool>(out);
   if (x.dims().size() >= y.dims().size()) {
     funcs::ElementwiseCompute<Functor, T, bool>(
-        ctx, x, y, axis, Functor(), out);
+        ctx, x, y, Functor(), out, axis);
   } else {
     funcs::ElementwiseCompute<InverseFunctor, T, bool>(
-        ctx, x, y, axis, InverseFunctor(), out);
+        ctx, x, y, InverseFunctor(), out, axis);
   }
 }
@@ -59,7 +59,7 @@ inline void CompareAllKernelImpl(const Context& ctx,
     tmp_data[0] = Functor()(x.data<T>()[0], y.data<T>()[0]);
   } else {
     funcs::ElementwiseCompute<Functor, T, bool>(
-        ctx, x, y, 0, Functor(), &tmp);
+        ctx, x, y, Functor(), &tmp, 0);
   }
   auto tmp_flat = EigenVector<bool>::Flatten(tmp);
   auto out_es = EigenScalar<bool>::From(*out);
...
@@ -91,8 +91,8 @@ struct DirichletSampler<CPUContext, T> {
                              true,
                              false);
-    funcs::ElementwiseCompute<funcs::DivideFunctor<T>, T, T>(
-        dev_ctx, gamma_samples, gamma_sum, -1, funcs::DivideFunctor<T>(), out);
+    funcs::ElementwiseCompute<funcs::DivideFunctor<T>, T>(
+        dev_ctx, gamma_samples, gamma_sum, funcs::DivideFunctor<T>(), out);
   }
 };
...
@@ -38,10 +38,10 @@ void DivideRawKernel(const Context& dev_ctx,
   auto y_dims = y.dims();
   if (x_dims.size() >= y_dims.size()) {
     funcs::ElementwiseCompute<funcs::DivideFunctor<T>, T>(
-        dev_ctx, x, y, axis, funcs::DivideFunctor<T>(), out);
+        dev_ctx, x, y, funcs::DivideFunctor<T>(), out, axis);
   } else {
     funcs::ElementwiseCompute<funcs::InverseDivideFunctor<T>, T>(
-        dev_ctx, x, y, axis, funcs::InverseDivideFunctor<T>(), out);
+        dev_ctx, x, y, funcs::InverseDivideFunctor<T>(), out, axis);
   }
 }
 }
...
@@ -75,7 +75,7 @@ void HeavisideKernel(const Context& dev_ctx,
   // allocate memory for out
   dev_ctx.template Alloc<T>(out);
   funcs::ElementwiseCompute<funcs::ElementwiseHeavisideFunctor<T>, T>(
-      dev_ctx, x, y, -1, funcs::ElementwiseHeavisideFunctor<T>(), out);
+      dev_ctx, x, y, funcs::ElementwiseHeavisideFunctor<T>(), out);
 }
 } // namespace phi
...
@@ -68,20 +68,15 @@ void LayerNormGradKernel(const Context& dev_ctx,
     temp_norm.Resize(matrix_shape);
     dev_ctx.template Alloc<T>(&temp_norm);
     // get x_norm
-    phi::funcs::ElementwiseCompute<funcs::SubtractFunctor<T>, T, T>(
-        dev_ctx,
-        x_tmp,
-        mean,
-        /*axis*/ 0,
-        funcs::SubtractFunctor<T>(),
-        &temp_norm);
-    phi::funcs::ElementwiseCompute<funcs::DivAndSqrtFunctor<T>, T, T>(
+    phi::funcs::ElementwiseCompute<funcs::SubtractFunctor<T>, T>(
+        dev_ctx, x_tmp, mean, funcs::SubtractFunctor<T>(), &temp_norm, 0);
+    phi::funcs::ElementwiseCompute<funcs::DivAndSqrtFunctor<T>, T>(
         dev_ctx,
         temp_norm,
         variance,
-        /*axis*/ 0,
         funcs::DivAndSqrtFunctor<T>(static_cast<T>(epsilon)),
-        &temp_norm);
+        &temp_norm,
+        0);
   }
   if (d_bias) {
@@ -90,8 +85,8 @@ void LayerNormGradKernel(const Context& dev_ctx,
   }
   if (d_scale) {
     dev_ctx.template Alloc<T>(d_scale);
-    phi::funcs::ElementwiseCompute<funcs::MultiplyFunctor<T>, T, T>(
-        dev_ctx, temp_norm, d_y, 0, funcs::MultiplyFunctor<T>(), &temp);
+    phi::funcs::ElementwiseCompute<funcs::MultiplyFunctor<T>, T>(
+        dev_ctx, temp_norm, d_y, funcs::MultiplyFunctor<T>(), &temp, 0);
     colwise_sum(dev_ctx, temp, d_scale);
   }
@@ -107,70 +102,45 @@ void LayerNormGradKernel(const Context& dev_ctx,
     if (d_scale) {
       // dy_dx
-      phi::funcs::ElementwiseCompute<funcs::MultiplyFunctor<T>, T, T>(
-          dev_ctx, d_y, *scale, /*axis*/ 1, funcs::MultiplyFunctor<T>(), &temp);
+      phi::funcs::ElementwiseCompute<funcs::MultiplyFunctor<T>, T>(
+          dev_ctx, d_y, *scale, funcs::MultiplyFunctor<T>(), &temp, 1);
       phi::Copy<Context>(dev_ctx, temp, dev_ctx.GetPlace(), false, d_x);
       // dy_dmean_dx
       row_mean(dev_ctx, temp, &temp_vec);
-      phi::funcs::ElementwiseCompute<funcs::SubtractFunctor<T>, T, T>(
-          dev_ctx,
-          *d_x,
-          temp_vec,
-          /*axis*/ 0,
-          funcs::SubtractFunctor<T>(),
-          d_x);
+      phi::funcs::ElementwiseCompute<funcs::SubtractFunctor<T>, T>(
+          dev_ctx, *d_x, temp_vec, funcs::SubtractFunctor<T>(), d_x, 0);
       // dy_var_dx
-      phi::funcs::ElementwiseCompute<funcs::MultiplyFunctor<T>, T, T>(
-          dev_ctx,
-          temp,
-          temp_norm,
-          /*axis*/ 0,
-          funcs::MultiplyFunctor<T>(),
-          &temp);
+      phi::funcs::ElementwiseCompute<funcs::MultiplyFunctor<T>, T>(
+          dev_ctx, temp, temp_norm, funcs::MultiplyFunctor<T>(), &temp, 0);
     } else {
       // dy_dx
       phi::Copy<Context>(dev_ctx, d_y, dev_ctx.GetPlace(), false, d_x);
       // dy_dmean_dx
       row_mean(dev_ctx, d_y, &temp_vec);
-      phi::funcs::ElementwiseCompute<funcs::SubtractFunctor<T>, T, T>(
-          dev_ctx,
-          *d_x,
-          temp_vec,
-          /*axis*/ 0,
-          funcs::SubtractFunctor<T>(),
-          d_x);
+      phi::funcs::ElementwiseCompute<funcs::SubtractFunctor<T>, T>(
+          dev_ctx, *d_x, temp_vec, funcs::SubtractFunctor<T>(), d_x, 0);
       // dy_var_dx
-      phi::funcs::ElementwiseCompute<funcs::MultiplyFunctor<T>, T, T>(
-          dev_ctx,
-          d_y,
-          temp_norm,
-          /*axis*/ 0,
-          funcs::MultiplyFunctor<T>(),
-          &temp);
+      phi::funcs::ElementwiseCompute<funcs::MultiplyFunctor<T>, T>(
+          dev_ctx, d_y, temp_norm, funcs::MultiplyFunctor<T>(), &temp, 0);
     }
     // dy_var_dx
     row_mean(dev_ctx, temp, &temp_vec);
-    phi::funcs::ElementwiseCompute<funcs::MultiplyFunctor<T>, T, T>(
-        dev_ctx,
-        temp_norm,
-        temp_vec,
-        /*axis*/ 0,
-        funcs::MultiplyFunctor<T>(),
-        &temp);
-    phi::funcs::ElementwiseCompute<funcs::SubtractFunctor<T>, T, T>(
-        dev_ctx, *d_x, temp, /*axis*/ 0, funcs::SubtractFunctor<T>(), d_x);
-    phi::funcs::ElementwiseCompute<funcs::DivAndSqrtFunctor<T>, T, T>(
+    phi::funcs::ElementwiseCompute<funcs::MultiplyFunctor<T>, T>(
+        dev_ctx, temp_norm, temp_vec, funcs::MultiplyFunctor<T>(), &temp, 0);
+    phi::funcs::ElementwiseCompute<funcs::SubtractFunctor<T>, T>(
+        dev_ctx, *d_x, temp, funcs::SubtractFunctor<T>(), d_x, 0);
+    phi::funcs::ElementwiseCompute<funcs::DivAndSqrtFunctor<T>, T>(
        dev_ctx,
        *d_x,
        variance,
-        /*axis*/ 0,
        funcs::DivAndSqrtFunctor<T>(static_cast<T>(epsilon)),
-        d_x);
+        d_x,
+        0);
     d_x->Resize(dx_dim);
   }
 }
...
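In the layer_norm kernels above, axis is now the trailing argument of ElementwiseCompute (defaulting to -1), and axis = 0 means the second operand is aligned with dimension 0 of the first. A minimal standalone sketch of that broadcast rule, not Paddle code, assuming x is a row-major [rows, cols] matrix and y holds one value per row (for example the per-row mean):

    #include <cstddef>
    #include <vector>

    // out[r][c] = x[r][c] - y[r]: y is broadcast along the columns, which is
    // what ElementwiseCompute(..., SubtractFunctor, ..., /*axis=*/0) computes
    // for these shapes.
    std::vector<float> subtract_rowwise(const std::vector<float>& x,
                                        const std::vector<float>& y,
                                        std::size_t rows, std::size_t cols) {
      std::vector<float> out(rows * cols);
      for (std::size_t r = 0; r < rows; ++r) {
        for (std::size_t c = 0; c < cols; ++c) {
          out[r * cols + c] = x[r * cols + c] - y[r];
        }
      }
      return out;
    }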
@@ -67,30 +67,30 @@ void LayerNormKernel(const Context& dev_ctx,
   // get variance
-  phi::funcs::ElementwiseCompute<funcs::SubAndSquareFunctor<T>, T, T>(
-      dev_ctx, x_tmp, *mean, 0, funcs::SubAndSquareFunctor<T>(), &out);
+  phi::funcs::ElementwiseCompute<funcs::SubAndSquareFunctor<T>, T>(
+      dev_ctx, x_tmp, *mean, funcs::SubAndSquareFunctor<T>(), &out, 0);
   row_mean(dev_ctx, out, var);
   // get x_norm
-  phi::funcs::ElementwiseCompute<funcs::SubtractFunctor<T>, T, T>(
-      dev_ctx, x_tmp, *mean, 0, funcs::SubtractFunctor<T>(), &out);
-  phi::funcs::ElementwiseCompute<funcs::DivAndSqrtFunctor<T>, T, T>(
+  phi::funcs::ElementwiseCompute<funcs::SubtractFunctor<T>, T>(
+      dev_ctx, x_tmp, *mean, funcs::SubtractFunctor<T>(), &out, 0);
+  phi::funcs::ElementwiseCompute<funcs::DivAndSqrtFunctor<T>, T>(
       dev_ctx,
       out,
       *var,
-      0,
       funcs::DivAndSqrtFunctor<T>(static_cast<T>(epsilon)),
-      &out);
+      &out,
+      0);
   if (scale) {
-    phi::funcs::ElementwiseCompute<funcs::MultiplyFunctor<T>, T, T>(
-        dev_ctx, out, *scale, 1, funcs::MultiplyFunctor<T>(), &out);
+    phi::funcs::ElementwiseCompute<funcs::MultiplyFunctor<T>, T>(
+        dev_ctx, out, *scale, funcs::MultiplyFunctor<T>(), &out, 1);
   }
   if (bias) {
-    phi::funcs::ElementwiseCompute<funcs::AddFunctor<T>, T, T>(
-        dev_ctx, out, *bias, 1, funcs::AddFunctor<T>(), &out);
+    phi::funcs::ElementwiseCompute<funcs::AddFunctor<T>, T>(
+        dev_ctx, out, *bias, funcs::AddFunctor<T>(), &out, 1);
   }
 #else
   PADDLE_ENFORCE_EQ(mean->numel(),
...
@@ -32,7 +32,7 @@ namespace phi {
                              DenseTensor* out) {                          \
     funcs::Logical##type##Functor<T> binary_func;                         \
     funcs::ElementwiseCompute<funcs::Logical##type##Functor<T>, T, bool>( \
-        dev_ctx, x, y, -1, binary_func, out);                             \
+        dev_ctx, x, y, binary_func, out);                                 \
   }
 DEFINE_LOGICAL_BINARY_KERNEL(And)
...
@@ -132,11 +132,10 @@ void MatrixRankTolKernel(const Context& dev_ctx,
   DenseTensor tol_tensor;
   tol_tensor.Resize(dim_out);
   dev_ctx.template Alloc<T>(&tol_tensor);
-  funcs::ElementwiseCompute<GreaterElementFunctor<T>, T, T>(
+  funcs::ElementwiseCompute<GreaterElementFunctor<T>, T>(
       dev_ctx,
       atol_tensor,
       rtol_tensor,
-      -1,
       GreaterElementFunctor<T>(),
       &tol_tensor);
@@ -151,17 +150,17 @@ void MatrixRankTolKernel(const Context& dev_ctx,
         dev_ctx,
         eigenvalue_tensor,
         tol_tensor,
-        axis,
         funcs::GreaterThanFunctor<T, int64_t>(),
-        &compare_result);
+        &compare_result,
+        axis);
   } else {
     funcs::ElementwiseCompute<funcs::LessThanFunctor<T, int64_t>, T, int>(
         dev_ctx,
         eigenvalue_tensor,
         tol_tensor,
-        axis,
         funcs::LessThanFunctor<T, int64_t>(),
-        &compare_result);
+        &compare_result,
+        axis);
   }
   phi::SumKernel<int64_t>(dev_ctx,
...
@@ -189,45 +189,29 @@ struct BroadcastDataLoader<Index, VecSize, false, kElementwise> {
   }
 };
-// Common broadcast data loader.
-template <int Index, int VecSize, bool IsBoundary>
-struct BroadcastDataLoader<Index, VecSize, IsBoundary, kBroadcast> {
-  template <typename Array1, typename Array2, typename Array3, typename ArgsT>
-  static __device__ __forceinline__ void Apply(const Array1 &ins,
-                                               ArgsT *args,
-                                               const Array2 &configs,
-                                               const Array3 &use_broadcast,
-                                               const int block_offset,
-                                               const int num,
-                                               const uint32_t numel) {
-    using Type = std::tuple_element_t<Index, ArgsT>;
-    uint32_t index_bc[VecSize];
-#pragma unroll
-    for (int k = 0; k < VecSize; ++k) {
-      index_bc[k] = 0;
-      std::get<Index>(args[k]) = static_cast<Type>(1);
-    }
-
-    uint32_t thread_offset = block_offset + threadIdx.x * VecSize;
-#pragma unroll
-    for (int k = 0; k < VecSize; ++k) {
-      uint32_t idx = thread_offset + k;
-      if (IsBoundary && idx == numel) {
-        break;
-      }
-#pragma unroll
-      for (int i = 0; i < phi::DDim::kMaxRank; ++i) {
-        if (i == configs[0].rank) break;
-        auto fast_divmoder = configs[0].divmoders[i].Divmod(idx);
-        idx = fast_divmoder.val[0];
-        index_bc[k] += fast_divmoder.val[1] * configs[Index].strides[i];
-      }
-    }
-
-#pragma unroll
-    for (int k = 0; k < VecSize; ++k) {
-      std::get<Index>(args[k]) =
-          reinterpret_cast<const _ptr_ Type *>(ins[Index])[index_bc[k]];
-    }
-  }
-};
+template <int Index, int VecSize>
+struct BroadcastDataInit {
+  template <typename ArgsT>
+  static __device__ __forceinline__ void Apply(ArgsT *args) {
+    using Type = std::tuple_element_t<Index, ArgsT>;
+#pragma unroll
+    for (int k = 0; k < VecSize; ++k) {
+      std::get<Index>(args[k]) = static_cast<Type>(1);
+    }
+  }
+};
+
+template <int Index, int VecSize>
+struct BroadcastDataSetter {
+  template <typename Array, typename ArgsT>
+  static __device__ __forceinline__ void Apply(const Array &ins,
+                                               ArgsT *args,
+                                               uint32_t index_bc[][VecSize]) {
+    using Type = std::tuple_element_t<Index, ArgsT>;
+#pragma unroll
+    for (int k = 0; k < VecSize; ++k) {
+      std::get<Index>(args[k]) =
+          reinterpret_cast<const _ptr_ Type *>(ins[Index])[index_bc[Index][k]];
+    }
+  }
+};
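The point of splitting the old loader into BroadcastDataInit and BroadcastDataSetter is that the divmod decomposition of the output index (see the kernel hunk below) now runs once per output element and its result is shared by all inputs, instead of being recomputed per input tensor. A standalone CPU-side sketch of that shared index math, not the CUDA code and with illustrative names:

    #include <cstddef>
    #include <cstdint>
    #include <vector>

    // Decompose a linear output index into per-dimension coordinates once,
    // then accumulate every input's offset from its own strides. The inner
    // "% / /" pair plays the role of FastDivMod in the kernel.
    std::vector<uint32_t> BroadcastOffsets(
        uint32_t out_idx,
        const std::vector<uint32_t>& out_dims,              // per-dim sizes
        const std::vector<std::vector<uint32_t>>& strides)  // one row per input
    {
      std::vector<uint32_t> offsets(strides.size(), 0);
      for (std::size_t i = 0; i < out_dims.size(); ++i) {
        uint32_t coord = out_idx % out_dims[i];
        out_idx /= out_dims[i];
        for (std::size_t j = 0; j < strides.size(); ++j) {
          offsets[j] += coord * strides[j][i];  // shared coord, per-input stride
        }
      }
      return offsets;
    }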
@@ -285,8 +269,30 @@ __device__ void VectorizedBroadcastKernelImpl(
   __simd__ ArgsT args[VecSize];
   __simd__ ConditionalT<OutT, NumOuts> result[VecSize];
-  BcUnroller<BroadcastDataLoader, IsBoundary, LoadType, VecSize, Arity>::step(
-      ins, args, configs, use_broadcast, block_offset, num, numel);
+  if (LoadType == kBroadcast) {
+    uint32_t index_bc[Arity][VecSize] = {0};
+    Unroller<BroadcastDataInit, VecSize, Arity>::step(args);
+    uint32_t thread_offset = block_offset + threadIdx.x * VecSize;
+#pragma unroll
+    for (int k = 0; k < VecSize; ++k) {
+      uint32_t idx = thread_offset + k;
+      if (IsBoundary && idx == numel) break;
+#pragma unroll
+      for (int i = 0; i < phi::DDim::kMaxRank; ++i) {
+        if (i == configs[0].rank) break;
+        auto fast_divmoder = configs[0].divmoders[i].Divmod(idx);
+        idx = fast_divmoder.val[0];
+#pragma unroll
+        for (int j = 0; j < Arity; ++j) {
+          index_bc[j][k] += fast_divmoder.val[1] * configs[j].strides[i];
+        }
+      }
+    }
+    Unroller<BroadcastDataSetter, VecSize, Arity>::step(ins, args, index_bc);
+  } else {
+    BcUnroller<BroadcastDataLoader, IsBoundary, LoadType, VecSize, Arity>::step(
+        ins, args, configs, use_broadcast, block_offset, num, numel);
+  }
   SameDimsElementwisePrimitiveCaller<ConditionalT<OutT, NumOuts>,
                                      VecSize,
@@ -783,11 +789,7 @@ struct LaunchBroadcastKernelWithInt64IndexHelper<OutT,
 };
 #endif
-template <ElementwiseType ET,
-          typename OutT,
-          typename Functor,
-          int kArity,
-          int NumOuts = 1>
+template <typename OutT, typename Functor, int kArity, int NumOuts = 1>
 void BroadcastKernelForDifferentVecSize(
     const KPDevice &ctx,
     const std::vector<const DenseTensor *> &ins,
@@ -922,16 +924,12 @@ void BroadcastKernelForDifferentVecSize(
   }
 }
-template <ElementwiseType ET,
-          typename InT,
-          typename OutT,
-          typename Functor,
-          int NumOuts = 1>
+template <typename OutT, typename Functor, int NumOuts = 1>
 void BroadcastKernel(const KPDevice &ctx,
                      const std::vector<const DenseTensor *> &ins,
                      std::vector<DenseTensor *> *outs,
-                     int axis,
-                     Functor func) {
+                     Functor func,
+                     int axis = -1) {
   // When there are multiple inputs, the outputs's rank should be equal the
   // maximum rank of all inputs.
   using Traits = phi::funcs::FunctionTraits<Functor>;
@@ -968,7 +966,7 @@ void BroadcastKernel(const KPDevice &ctx,
     max_rank = std::max(max_rank, (*outs)[0]->dims().size());
   }
   axis = axis == -1 ? max_rank - min_rank : axis;
-  BroadcastKernelForDifferentVecSize<ET, OutT, Functor, kArity, NumOuts>(
+  BroadcastKernelForDifferentVecSize<OutT, Functor, kArity, NumOuts>(
       ctx, ins, outs, axis, func);
 }
@@ -976,15 +974,14 @@ template <typename Functor, typename T, typename OutType = T>
 void ElementwiseCompute(const GPUContext &dev_ctx,
                         const DenseTensor &x,
                         const DenseTensor &y,
-                        int axis,
                         Functor func,
-                        DenseTensor *z) {
+                        DenseTensor *z,
+                        int axis = -1) {
   std::vector<const DenseTensor *> ins = {&x, &y};
   std::vector<DenseTensor *> outs = {z};
   dev_ctx.template Alloc<OutType>(z);
-  BroadcastKernel<ElementwiseType::kBinary, T, OutType, Functor, 1>(
-      dev_ctx, ins, &outs, axis, func);
+  BroadcastKernel<OutType, Functor, 1>(dev_ctx, ins, &outs, func, axis);
 }
 template <typename DeviceContext,
@@ -999,7 +996,7 @@ void DefaultElementwiseOperator(const DeviceContext &dev_ctx,
   auto x_dims = x.dims();
   auto y_dims = y.dims();
   dev_ctx.template Alloc<T>(z);
-  funcs::ElementwiseCompute<Functor, T>(dev_ctx, x, y, axis, Functor(), z);
+  funcs::ElementwiseCompute<Functor, T>(dev_ctx, x, y, Functor(), z, axis);
 }
 #else
@@ -1017,10 +1014,10 @@ void DefaultElementwiseOperator(const DeviceContext &dev_ctx,
   auto y_dims = y.dims();
   dev_ctx.template Alloc<T>(z);
   if (x_dims.size() >= y_dims.size()) {
-    funcs::ElementwiseCompute<Functor, T>(dev_ctx, x, y, axis, Functor(), z);
+    funcs::ElementwiseCompute<Functor, T>(dev_ctx, x, y, Functor(), z, axis);
   } else {
     funcs::ElementwiseCompute<InverseFunctor, T>(
-        dev_ctx, x, y, axis, InverseFunctor(), z);
+        dev_ctx, x, y, InverseFunctor(), z, axis);
   }
 }
 #endif
...
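BroadcastKernel already derives kArity from phi::funcs::FunctionTraits<Functor> (visible above), which is why the ElementwiseType tag and the InT parameter can be dropped: the functor's operator() fully determines the number and type of inputs. A self-contained sketch of that deduction idea, with illustrative names that only stand in for the real FunctionTraits:

    // Deduce the input arity of a functor from its operator() signature.
    template <typename F>
    struct FunctorTraits : FunctorTraits<decltype(&F::operator())> {};

    template <typename C, typename R, typename... Args>
    struct FunctorTraits<R (C::*)(Args...) const> {
      static constexpr int arity = sizeof...(Args);
    };

    template <typename T>
    struct AddFunctor {
      T operator()(T a, T b) const { return a + b; }
    };

    // A binary functor is recognized without any ElementwiseType::kBinary tag.
    static_assert(FunctorTraits<AddFunctor<float>>::arity == 2,
                  "arity deduced from the functor signature");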
@@ -191,25 +191,19 @@ __global__ void VectorizedRandomGenerator(const size_t n,
 }
 template <typename T>
-__global__ void DropOutNdForwardKernel(
-    const size_t n,
-    uint64_t seed,
-    const float dropout_prob,
-    const T* src,
-    uint8_t* mask,
-    uint64_t increment,
-    size_t main_offset,
-    DstFunctor<T> dst_functor,
-    MaskFunctor<T> mask_functor,
-    T* y,
-    int64_t N,
-    kps::details::BroadcastConfig broadcast_config,
-    const uint64_t* seed_ptr) {
+__global__ void VectorizedGeneratorMask(const size_t n,
+                                        uint64_t seed,
+                                        const float dropout_prob,
+                                        const T* src,
+                                        uint8_t* mask,
+                                        uint64_t increment,
+                                        size_t main_offset,
+                                        MaskFunctor<T> mask_functor,
+                                        const uint64_t* seed_ptr) {
   // Vectorized Generate Mask
   // kCount is 4 for curand_uniform4 is used
-  if (seed_ptr) {
-    seed = seed_ptr[0];
-  }
+  if (seed_ptr) seed = seed_ptr[0];
   constexpr int kCount = phi::funcs::uniform_distribution<float>::kReturnsCount;
   size_t idx = static_cast<size_t>(BLOCK_ID_X * BLOCK_NUM_X);
@@ -259,22 +253,6 @@ __global__ void DropOutNdForwardKernel(
     kps::WriteData<uint8_t, kCount, 1, true>(
         mask + fix, &mask_result[0], remainder);
   }
-  // Broadcast mask data and do elementwise operaiton with DstFunctor
-  CUDA_KERNEL_LOOP(i, N) {
-    uint32_t offset = 0u;
-    uint32_t idx = i;
-    // Use (j < phi::DDim::kMaxRank) conditiion rather than
-    // (j < broadcast_config.rank) for (#pragma unroll)
-#pragma unroll
-    for (int j = 0; j < phi::DDim::kMaxRank; ++j) {
-      if (j == broadcast_config.rank) break;
-      auto fast_divmoder = broadcast_config.divmoders[j].Divmod(idx);
-      idx = fast_divmoder.val[0];
-      offset += broadcast_config.strides[j] * fast_divmoder.val[1];
-    }
-    __syncthreads();
-    y[i] = dst_functor(src[i], mask[offset]);
-  }
 }
 template <typename T, typename MT>
@@ -348,18 +326,6 @@ void DropoutFwGPUKernelDriver(
       size / (block_size * kVecSize) * (block_size * kVecSize);
   if (is_dropout_nd) {
-    auto dst_functor =
-        DstFunctor<T>(1.0f - dropout_prob, upscale_in_train, x_numel);
-    std::vector<int64_t> out_dims =
-        std::move(phi::vectorize<int64_t>(x.dims()));
-    std::vector<int64_t> in_dims =
-        std::move(phi::vectorize<int64_t>(mask->dims()));
-    std::reverse(out_dims.begin(), out_dims.end());
-    std::reverse(in_dims.begin(), in_dims.end());
-    kps::details::BroadcastConfig broadcast_config(
-        out_dims, in_dims, x.dims().size());
     auto mask_functor = MaskFunctor<T>(1.0f - dropout_prob);
     bool copy_in_kernel = GetSeedDataAndIncrement(dev_ctx,
                                                   seed,
@@ -372,20 +338,22 @@ void DropoutFwGPUKernelDriver(
     const uint64_t* seed_ptr =
         copy_in_kernel ? seed->data<uint64_t>() : nullptr;
-    DropOutNdForwardKernel<T>
+    VectorizedGeneratorMask<T>
         <<<grid_size, block_size, 0, stream>>>(size,
                                                seed_data,
                                                dropout_prob,
                                                x_data,
                                                mask_data,
                                                increment,
                                                main_offset,
-                                               dst_functor,
                                                mask_functor,
-                                               y_data,
-                                               y->numel(),
-                                               broadcast_config,
                                                seed_ptr);
+    auto dst_functor =
+        DstFunctor<T>(1.0f - dropout_prob, upscale_in_train, x_numel);
+    std::vector<const phi::DenseTensor*> ins = {&x, mask};
+    std::vector<phi::DenseTensor*> outs = {y};
+    phi::funcs::BroadcastKernel<T>(dev_ctx, ins, &outs, dst_functor);
   } else {
     bool copy_in_kernel = GetSeedDataAndIncrement(
         dev_ctx, seed, is_fix_seed, seed_val, offset, &seed_data, &increment);
@@ -469,30 +437,13 @@ void DropoutGradGPUKernelDriver(const phi::GPUContext& dev_ctx,
   MT factor = upscale_in_train
                   ? static_cast<MT>(1.0f / (1.0f - dropout_prob))
                   : static_cast<MT>(1.0f);
+  std::vector<const phi::DenseTensor*> ins = {&grad_y, &mask};
+  std::vector<phi::DenseTensor*> outs = {grad_x};
   if (is_dropout_nd) {
-    phi::DenseTensor broadcasted_mask;
-    broadcasted_mask.Resize(grad_y.dims());
-    dev_ctx.template Alloc<uint8_t>(&broadcasted_mask);
-    std::vector<const phi::DenseTensor*> broadcast_ins = {&mask};
-    std::vector<phi::DenseTensor*> broadcast_outs = {&broadcasted_mask};
-    phi::funcs::BroadcastKernel<phi::ElementwiseType::kUnary,
-                                uint8_t,
-                                uint8_t>(dev_ctx,
-                                         broadcast_ins,
-                                         &broadcast_outs,
-                                         -1,
-                                         kps::IdentityFunctor<uint8_t>());
-    std::vector<const phi::DenseTensor*> ins = {&grad_y, &broadcasted_mask};
-    std::vector<phi::DenseTensor*> outs = {grad_x};
-    phi::funcs::ElementwiseKernel<T>(
+    phi::funcs::BroadcastKernel<T>(
         dev_ctx, ins, &outs, CudaDropoutGradFunctor<T>(factor));
   } else {
-    std::vector<const phi::DenseTensor*> ins = {&grad_y, &mask};
-    std::vector<phi::DenseTensor*> outs = {grad_x};
     phi::funcs::ElementwiseKernel<T>(
         dev_ctx, ins, &outs, CudaDropoutGradFunctor<T>(factor));
   }
...
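After this rewrite, the dropout-nd forward path generates only the mask in VectorizedGeneratorMask and leaves the broadcast-and-apply step to BroadcastKernel with DstFunctor; the backward path likewise feeds the un-broadcast mask straight into BroadcastKernel instead of materializing a broadcasted copy. A standalone sketch of what the fused apply step computes, assuming (as its constructor arguments above suggest) that DstFunctor multiplies by the mask and, when upscale_in_train is set, rescales kept values by 1 / (1 - dropout_prob); this is not the CUDA kernel:

    #include <cstddef>
    #include <cstdint>
    #include <vector>

    // y[i] = x[i] * mask[i] * factor, where mask has already been broadcast
    // to x's shape by the broadcast machinery.
    std::vector<float> ApplyDropoutMask(const std::vector<float>& x,
                                        const std::vector<uint8_t>& mask,
                                        float dropout_prob,
                                        bool upscale_in_train) {
      const float factor =
          upscale_in_train ? 1.0f / (1.0f - dropout_prob) : 1.0f;
      std::vector<float> y(x.size());
      for (std::size_t i = 0; i < x.size(); ++i) {
        y[i] = x[i] * static_cast<float>(mask[i]) * factor;
      }
      return y;
    }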
@@ -35,7 +35,6 @@ namespace kps = phi::kps;
 namespace phi {
-enum ElementwiseType { kUnary = 1, kBinary = 2, kTernary = 3 };
 /* Packing scalar type T(float, int etc.) into Array<T, NumOuts> type
    for supporting multiple-output feature in elementwise system.*/
 template <class T, int Num>
@@ -369,9 +368,9 @@ template <typename Functor, typename T, typename OutType = T>
 void ElementwiseCompute(const CPUContext &dev_ctx,
                         const DenseTensor &x,
                         const DenseTensor &y,
-                        int axis,
                         Functor func,
-                        DenseTensor *z) {
+                        DenseTensor *z,
+                        int axis = -1) {
   dev_ctx.Alloc<OutType>(z);
   auto x_dims = x.dims();
   auto y_dims = y.dims();
...
@@ -112,8 +112,8 @@ class AttnMatMul {
     // bias_out = output + bias
     std::vector<const phi::DenseTensor*> ins = {output, bias};
     std::vector<phi::DenseTensor*> outs = {bias_out};
-    phi::funcs::BroadcastKernel<phi::ElementwiseType::kBinary, T, T>(
-        dev_ctx_, ins, &outs, -1, phi::funcs::AddFunctor<T>());
+    phi::funcs::BroadcastKernel<T>(
+        dev_ctx_, ins, &outs, phi::funcs::AddFunctor<T>());
   }
 }
...
@@ -258,12 +258,11 @@ class FMHARef {
     ins.emplace_back(src_mask_tensor);
     outs.emplace_back(src_mask_out_tensor);
     int elewise_add_axis = -1;
-    phi::funcs::BroadcastKernel<phi::ElementwiseType::kBinary, T, T>(
-        dev_ctx_,
-        ins,
-        &outs,
-        elewise_add_axis,
-        phi::funcs::AddFunctor<T>());
+    phi::funcs::BroadcastKernel<T>(dev_ctx_,
+                                   ins,
+                                   &outs,
+                                   phi::funcs::AddFunctor<T>(),
+                                   elewise_add_axis);
     phi::SoftmaxForwardCUDAKernelDriver<T>(
         dev_ctx_, *src_mask_out_tensor, softmax_axis, softmax_out_tensor);
@@ -435,12 +434,11 @@ class FMHARef {
     ins.emplace_back(src_mask_tensor);
     outs.emplace_back(src_mask_out_tensor);
     int elewise_add_axis = -1;
-    phi::funcs::BroadcastKernel<phi::ElementwiseType::kBinary, T, T>(
-        dev_ctx_,
-        ins,
-        &outs,
-        elewise_add_axis,
-        phi::funcs::AddFunctor<T>());
+    phi::funcs::BroadcastKernel<T>(dev_ctx_,
+                                   ins,
+                                   &outs,
+                                   phi::funcs::AddFunctor<T>(),
+                                   elewise_add_axis);
     phi::SoftmaxForwardCUDAKernelDriver<T>(
         dev_ctx_, *src_mask_out_tensor, softmax_axis, softmax_out_tensor);
...
@@ -106,8 +106,8 @@ struct DirichletSampler<GPUContext, T> {
                              {new_shape.size() - 1},
                              true,
                              false);
-    funcs::ElementwiseCompute<funcs::DivideFunctor<T>, T, T>(
-        dev_ctx, gamma_samples, gamma_sum, -1, funcs::DivideFunctor<T>(), out);
+    funcs::ElementwiseCompute<funcs::DivideFunctor<T>, T>(
+        dev_ctx, gamma_samples, gamma_sum, funcs::DivideFunctor<T>(), out);
   }
 };
 } // namespace phi
...
@@ -37,22 +37,21 @@ void DivideGradKernel(const Context& dev_ctx,
   const auto place = dev_ctx.GetPlace();
   if (dx != nullptr && dy != nullptr) {
     std::vector<const DenseTensor*> ins = {&dout, &out, &y};
-    GetGradXAndYOut<ElementwiseType::kTernary, T>(
-        dev_ctx,
-        place,
-        axis,
-        ins,
-        dout,
-        dx,
-        dy,
-        funcs::DivGradXYFunctor<T, T>());
+    GetGradXAndYOut<T>(dev_ctx,
+                       place,
+                       axis,
+                       ins,
+                       dout,
+                       dx,
+                       dy,
+                       funcs::DivGradXYFunctor<T, T>());
   } else if (dx != nullptr && dy == nullptr) {
     std::vector<const DenseTensor*> ins = {&dout, &y};
-    GetGradXOrYOut<ElementwiseType::kBinary, T>(
+    GetGradXOrYOut<T>(
         dev_ctx, place, axis, ins, dout, dx, funcs::DivGradXFunctor<T>());
   } else if (dy != nullptr && dx == nullptr) {
     std::vector<const DenseTensor*> ins = {&dout, &out, &y};
-    GetGradXOrYOut<ElementwiseType::kTernary, T>(
+    GetGradXOrYOut<T>(
        dev_ctx, place, axis, ins, dout, dy, funcs::DivGradYFunctor<T>());
   }
 }
...
@@ -35,7 +35,7 @@ void ReduceWrapper(const GPUContext &dev_ctx,
       dev_ctx, *src, dst, kps::IdentityFunctor<T>(), reduce_dims);
 }
-template <ElementwiseType ET, typename T, typename Functor>
+template <typename T, typename Functor>
 void GetGradXAndYOut(const GPUContext &dev_ctx,
                      const Place &place,
                      int axis,
@@ -67,8 +67,7 @@ void GetGradXAndYOut(const GPUContext &dev_ctx,
     outs = {&tmp_dx, &tmp_dy};
   }
-  funcs::BroadcastKernel<ET, T, T, decltype(func), 2>(
-      dev_ctx, ins, &outs, axis, func);
+  funcs::BroadcastKernel<T, decltype(func), 2>(dev_ctx, ins, &outs, func, axis);
   if (dx->dims() != dout.dims() && dy->dims() == dout.dims()) {
     ReduceWrapper<T>(dev_ctx, axis, &tmp_dx, dx);
@@ -80,7 +79,7 @@ void GetGradXAndYOut(const GPUContext &dev_ctx,
   }
 }
-template <ElementwiseType ET, typename T, typename Functor>
+template <typename T, typename Functor>
 void GetGradXOrYOut(const GPUContext &dev_ctx,
                     const Place &place,
                     int axis,
@@ -100,7 +99,7 @@ void GetGradXOrYOut(const GPUContext &dev_ctx,
     outs = {dxy};
   }
-  funcs::BroadcastKernel<ET, T, T>(dev_ctx, ins, &outs, axis, func);
+  funcs::BroadcastKernel<T>(dev_ctx, ins, &outs, func, axis);
   if (dxy->dims() != dout.dims()) {
     ReduceWrapper<T>(dev_ctx, axis, &tmp_dxy, dxy);
   }
@@ -342,22 +341,21 @@ void ElementwiseDivGrad(const GPUContext &dev_ctx,
   const auto place = dev_ctx.GetPlace();
   if (dx != nullptr && dy != nullptr) {
     std::vector<const DenseTensor *> ins = {&dout, &out, &y};
-    GetGradXAndYOut<ElementwiseType::kTernary, T>(
-        dev_ctx,
-        place,
-        axis,
-        ins,
-        dout,
-        dx,
-        dy,
-        funcs::DivGradXYFunctor<T, T>());
+    GetGradXAndYOut<T>(dev_ctx,
+                       place,
+                       axis,
+                       ins,
+                       dout,
+                       dx,
+                       dy,
+                       funcs::DivGradXYFunctor<T, T>());
   } else if (dx != nullptr && dy == nullptr) {
     std::vector<const DenseTensor *> ins = {&dout, &y};
-    GetGradXOrYOut<ElementwiseType::kBinary, T>(
+    GetGradXOrYOut<T>(
         dev_ctx, place, axis, ins, dout, dx, funcs::DivGradXFunctor<T>());
   } else if (dy != nullptr && dx == nullptr) {
     std::vector<const DenseTensor *> ins = {&dout, &out, &y};
-    GetGradXOrYOut<ElementwiseType::kTernary, T>(
+    GetGradXOrYOut<T>(
        dev_ctx, place, axis, ins, dout, dy, funcs::DivGradYFunctor<T>());
   }
 }
@@ -380,22 +378,21 @@ void ElementwiseMulGrad(const GPUContext &dev_ctx,
   if (dx != nullptr && dy != nullptr) {
     std::vector<const DenseTensor *> ins = {&dout, &y, &x};
-    GetGradXAndYOut<ElementwiseType::kTernary, T>(
-        dev_ctx,
-        place,
-        axis,
-        ins,
-        dout,
-        dx,
-        dy,
-        funcs::MultiplyGradXYFunctor<T, T>());
+    GetGradXAndYOut<T>(dev_ctx,
+                       place,
+                       axis,
+                       ins,
+                       dout,
+                       dx,
+                       dy,
+                       funcs::MultiplyGradXYFunctor<T, T>());
   } else if (dx != nullptr && dy == nullptr) {
     std::vector<const DenseTensor *> ins = {&dout, &y};
-    GetGradXOrYOut<ElementwiseType::kBinary, T>(
+    GetGradXOrYOut<T>(
         dev_ctx, place, axis, ins, dout, dx, funcs::MultiplyGradFunctor<T>());
   } else if (dx == nullptr && dy != nullptr) {
     std::vector<const DenseTensor *> ins = {&dout, &x};
-    GetGradXOrYOut<ElementwiseType::kBinary, T>(
+    GetGradXOrYOut<T>(
        dev_ctx, place, axis, ins, dout, dy, funcs::MultiplyGradFunctor<T>());
   }
 }
...
@@ -37,22 +37,21 @@ void MaximumGradKernel(const Context& dev_ctx,
   int axis = -1;
   if (dx != nullptr && dy != nullptr) {
     std::vector<const DenseTensor*> ins = {&x, &y, &dout};
-    GetGradXAndYOut<ElementwiseType::kTernary, T>(
-        dev_ctx,
-        place,
-        axis,
-        ins,
-        dout,
-        dx,
-        dy,
-        funcs::MaxGradXYFunctor<T, T>());
+    GetGradXAndYOut<T>(dev_ctx,
+                       place,
+                       axis,
+                       ins,
+                       dout,
+                       dx,
+                       dy,
+                       funcs::MaxGradXYFunctor<T, T>());
   } else if (dx != nullptr && dy == nullptr) {
     std::vector<const DenseTensor*> ins = {&x, &y, &dout};
-    GetGradXOrYOut<ElementwiseType::kBinary, T>(
+    GetGradXOrYOut<T>(
        dev_ctx, place, axis, ins, dout, dx, funcs::MaxGradXFunctor<T>());
   } else if (dy != nullptr && dx == nullptr) {
     std::vector<const DenseTensor*> ins = {&x, &y, &dout};
-    GetGradXOrYOut<ElementwiseType::kTernary, T>(
+    GetGradXOrYOut<T>(
        dev_ctx, place, axis, ins, dout, dy, funcs::MaxGradYFunctor<T>());
   }
 }
@@ -68,22 +67,21 @@ void MinimumGradKernel(const Context& dev_ctx,
   int axis = -1;
   if (dx != nullptr && dy != nullptr) {
     std::vector<const DenseTensor*> ins = {&x, &y, &dout};
-    GetGradXAndYOut<ElementwiseType::kTernary, T>(
-        dev_ctx,
-        place,
-        axis,
-        ins,
-        dout,
-        dx,
-        dy,
-        funcs::MinGradXYFunctor<T, T>());
+    GetGradXAndYOut<T>(dev_ctx,
+                       place,
+                       axis,
+                       ins,
+                       dout,
+                       dx,
+                       dy,
+                       funcs::MinGradXYFunctor<T, T>());
   } else if (dx != nullptr && dy == nullptr) {
     std::vector<const DenseTensor*> ins = {&x, &y, &dout};
-    GetGradXOrYOut<ElementwiseType::kBinary, T>(
+    GetGradXOrYOut<T>(
        dev_ctx, place, axis, ins, dout, dx, funcs::MinGradXFunctor<T>());
   } else if (dy != nullptr && dx == nullptr) {
     std::vector<const DenseTensor*> ins = {&x, &y, &dout};
-    GetGradXOrYOut<ElementwiseType::kTernary, T>(
+    GetGradXOrYOut<T>(
        dev_ctx, place, axis, ins, dout, dy, funcs::MinGradYFunctor<T>());
   }
 }
...
@@ -74,8 +74,7 @@ void ExpandAsKernel(const Context& ctx,
   ctx.template Alloc<T>(out);
   std::vector<const DenseTensor*> ins = {&x};
   std::vector<DenseTensor*> outs = {out};
-  phi::funcs::BroadcastKernel<ElementwiseType::kUnary, T, T>(
-      ctx, ins, &outs, -1, kps::IdentityFunctor<T>());
+  phi::funcs::BroadcastKernel<T>(ctx, ins, &outs, kps::IdentityFunctor<T>());
 }
 } // namespace phi
...
@@ -73,8 +73,7 @@ void ExpandKernel(const Context& ctx,
   ctx.template Alloc<T>(out);
   std::vector<const DenseTensor*> ins = {&x};
   std::vector<DenseTensor*> outs = {out};
-  phi::funcs::BroadcastKernel<ElementwiseType::kUnary, T, T>(
-      ctx, ins, &outs, -1, kps::IdentityFunctor<T>());
+  phi::funcs::BroadcastKernel<T>(ctx, ins, &outs, kps::IdentityFunctor<T>());
 }
 } // namespace phi
...
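The expand kernels above are the degenerate unary case: BroadcastKernel with an identity functor, so every output element just copies the input element its index maps back to. A trivial standalone sketch of that semantics for a [1, cols] input expanded to [rows, cols] (not Paddle code):

    #include <cstddef>
    #include <vector>

    std::vector<float> ExpandRows(const std::vector<float>& row,
                                  std::size_t rows) {
      std::vector<float> out;
      out.reserve(rows * row.size());
      for (std::size_t r = 0; r < rows; ++r) {
        // identity per element; only the index mapping does the "expand"
        out.insert(out.end(), row.begin(), row.end());
      }
      return out;
    }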
@@ -407,11 +407,10 @@ void MatrixRankTolKernel(const Context& dev_ctx,
   tol_tensor.Resize(dim_out);
   dev_ctx.template Alloc<T>(&tol_tensor);
-  funcs::ElementwiseCompute<GreaterElementFunctor<T>, T, T>(
+  funcs::ElementwiseCompute<GreaterElementFunctor<T>, T>(
       dev_ctx,
       atol_tensor,
       rtol_tensor,
-      -1,
       GreaterElementFunctor<T>(),
       &tol_tensor);
@@ -421,12 +420,10 @@ void MatrixRankTolKernel(const Context& dev_ctx,
   compare_result.Resize(detail::NewAxisDim(dim_out, k));
   dev_ctx.template Alloc<int64_t>(&compare_result);
-  int axis = -1;
   funcs::ElementwiseCompute<funcs::GreaterThanFunctor<T, int64_t>, T, int64_t>(
       dev_ctx,
       eigenvalue_tensor,
       tol_tensor,
-      axis,
       funcs::GreaterThanFunctor<T, int64_t>(),
       &compare_result);
...
@@ -78,8 +78,8 @@ void ReduceCudaAMaxAMinGrad(const Context& dev_ctx,
   // 1. equal_out = Equal(x, y)
   std::vector<const phi::DenseTensor*> equal_inputs = {&new_y, new_in_tensor};
   std::vector<phi::DenseTensor*> equal_outputs = {&equal_out_tensor};
-  funcs::BroadcastKernel<phi::ElementwiseType::kBinary, T, T>(
-      dev_ctx, equal_inputs, &equal_outputs, 0, funcs::EqualFunctor<T>());
+  funcs::BroadcastKernel<T>(
+      dev_ctx, equal_inputs, &equal_outputs, funcs::EqualFunctor<T>(), 0);
   // 2. equal_count = reduceSum(equal_out)
   using MPType = typename kps::details::MPTypeTrait<T>::Type;
   phi::funcs::
@@ -95,15 +95,15 @@ void ReduceCudaAMaxAMinGrad(const Context& dev_ctx,
   std::vector<const phi::DenseTensor*> mul_inputs = {&new_dout,
                                                      &equal_out_tensor};
   std::vector<phi::DenseTensor*> mul_outputs = {&equal_out_tensor};
-  funcs::BroadcastKernel<phi::ElementwiseType::kBinary, T, T>(
-      dev_ctx, mul_inputs, &mul_outputs, 0, funcs::MultiplyFunctor<T>());
+  funcs::BroadcastKernel<T>(
+      dev_ctx, mul_inputs, &mul_outputs, funcs::MultiplyFunctor<T>(), 0);
   // 4. dx = Div(dx, equal_out)
   std::vector<const phi::DenseTensor*> grad_inputs = {&equal_out_tensor,
                                                       equal_count};
   std::vector<phi::DenseTensor*> grad_outputs = {new_dx_tensor};
-  funcs::BroadcastKernel<phi::ElementwiseType::kBinary, T, T>(
-      dev_ctx, grad_inputs, &grad_outputs, 0, funcs::DivideFunctor<T>());
+  funcs::BroadcastKernel<T>(
+      dev_ctx, grad_inputs, &grad_outputs, funcs::DivideFunctor<T>(), 0);
   delete equal_out;
   delete equal_count;
 }
...
@@ -28,7 +28,7 @@
 namespace phi {
-template <typename InT, typename Functor>
+template <typename Functor>
 void ReduceGrad(const GPUContext& dev_ctx,
                 DenseTensor* d_out,
                 DenseTensor* d_x,
@@ -36,14 +36,13 @@ void ReduceGrad(const GPUContext& dev_ctx,
                 Functor functor) {
   std::vector<const DenseTensor*> inputs = {d_out};
   std::vector<DenseTensor*> outputs = {d_x};
-  PD_VISIT_ALL_TYPES(
-      out_dtype, "BroadcastKernel", ([&] {
-        funcs::BroadcastKernel<phi::ElementwiseType::kUnary, InT, data_t>(
-            dev_ctx, inputs, &outputs, 0, functor);
-      }));
+  PD_VISIT_ALL_TYPES(out_dtype, "BroadcastKernel", ([&] {
+                       funcs::BroadcastKernel<data_t>(
+                           dev_ctx, inputs, &outputs, functor, 0);
+                     }));
 }
-template <typename T, typename OutT, typename Context, typename Functor>
+template <typename OutT, typename Context, typename Functor>
 void ReduceGradKernel(const Context& dev_ctx,
                       const DenseTensor& x,
                       const DenseTensor& out_grad,
@@ -79,8 +78,7 @@ void ReduceGradKernel(const Context& dev_ctx,
   auto pt_d_x = *d_x;
   std::vector<const DenseTensor*> inputs = {&pt_d_out};
   std::vector<DenseTensor*> outputs = {&pt_d_x};
-  funcs::BroadcastKernel<phi::ElementwiseType::kUnary, T, OutT>(
-      dev_ctx, inputs, &outputs, 0, functor);
+  funcs::BroadcastKernel<OutT>(dev_ctx, inputs, &outputs, functor, 0);
 }
 } // namespace phi
...
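ReduceGrad no longer needs InT as a template parameter because PD_VISIT_ALL_TYPES already picks the concrete instantiation from the runtime out_dtype. A self-contained sketch of that dispatch idea; names here are illustrative, not Paddle's macro or enum:

    #include <cstdint>
    #include <stdexcept>

    enum class DataType { kFloat32, kInt64 };

    // Map a runtime dtype to a compile-time type tag and hand it to a generic
    // visitor, so the caller never spells the type as a template argument.
    template <typename Visitor>
    void VisitDtype(DataType dtype, Visitor&& visit) {
      switch (dtype) {
        case DataType::kFloat32: visit(float{}); break;
        case DataType::kInt64:   visit(int64_t{}); break;
        default: throw std::runtime_error("unsupported dtype");
      }
    }

    // Usage: VisitDtype(out_dtype, [&](auto tag) {
    //   using data_t = decltype(tag);
    //   // instantiate the broadcast kernel for data_t here
    // });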
@@ -62,14 +62,14 @@ void ReduceMaxGradKernel(const Context& dev_ctx,
   // 1. equal_out = Equal(x, y)
   std::vector<const phi::DenseTensor*> equal_inputs = {&new_out, &x};
   std::vector<phi::DenseTensor*> equal_outputs = {equal_out};
-  funcs::BroadcastKernel<phi::ElementwiseType::kBinary, T, T>(
-      dev_ctx, equal_inputs, &equal_outputs, 0, funcs::EqualFunctor<T>());
+  funcs::BroadcastKernel<T>(
+      dev_ctx, equal_inputs, &equal_outputs, funcs::EqualFunctor<T>(), 0);
   // 2. dx = dout * 1
   std::vector<const phi::DenseTensor*> mul_inputs = {&new_out_grad, equal_out};
   std::vector<phi::DenseTensor*> mul_outputs = {x_grad};
-  funcs::BroadcastKernel<phi::ElementwiseType::kBinary, T, T>(
-      dev_ctx, mul_inputs, &mul_outputs, 0, funcs::MultiplyFunctor<T>());
+  funcs::BroadcastKernel<T>(
+      dev_ctx, mul_inputs, &mul_outputs, funcs::MultiplyFunctor<T>(), 0);
   delete equal_out;
 }
 } // namespace phi
...
...@@ -53,8 +53,8 @@ void ReduceMeanGradKernel(const Context& dev_ctx, ...@@ -53,8 +53,8 @@ void ReduceMeanGradKernel(const Context& dev_ctx,
std::vector<DenseTensor*> outputs = {x_grad}; std::vector<DenseTensor*> outputs = {x_grad};
using MPType = typename kps::details::MPTypeTrait<T>::Type; using MPType = typename kps::details::MPTypeTrait<T>::Type;
funcs::BroadcastKernel<phi::ElementwiseType::kUnary, T, T>( funcs::BroadcastKernel<T>(
dev_ctx, inputs, &outputs, 0, kps::DivideFunctor<T, MPType>(reduce_num)); dev_ctx, inputs, &outputs, kps::DivideFunctor<T, MPType>(reduce_num), 0);
} }
} // namespace phi } // namespace phi
......
...@@ -62,14 +62,14 @@ void ReduceMinGradKernel(const Context& dev_ctx, ...@@ -62,14 +62,14 @@ void ReduceMinGradKernel(const Context& dev_ctx,
// 1. equal_out = Equal(x, y) // 1. equal_out = Equal(x, y)
std::vector<const phi::DenseTensor*> equal_inputs = {&new_out, &x}; std::vector<const phi::DenseTensor*> equal_inputs = {&new_out, &x};
std::vector<phi::DenseTensor*> equal_outputs = {equal_out}; std::vector<phi::DenseTensor*> equal_outputs = {equal_out};
funcs::BroadcastKernel<phi::ElementwiseType::kBinary, T, T>( funcs::BroadcastKernel<T>(
dev_ctx, equal_inputs, &equal_outputs, 0, funcs::EqualFunctor<T>()); dev_ctx, equal_inputs, &equal_outputs, funcs::EqualFunctor<T>(), 0);
// 2. dx = dout * 1 // 2. dx = dout * 1
std::vector<const phi::DenseTensor*> mul_inputs = {&new_out_grad, equal_out}; std::vector<const phi::DenseTensor*> mul_inputs = {&new_out_grad, equal_out};
std::vector<phi::DenseTensor*> mul_outputs = {x_grad}; std::vector<phi::DenseTensor*> mul_outputs = {x_grad};
funcs::BroadcastKernel<phi::ElementwiseType::kBinary, T, T>( funcs::BroadcastKernel<T>(
dev_ctx, mul_inputs, &mul_outputs, 0, funcs::MultiplyFunctor<T>()); dev_ctx, mul_inputs, &mul_outputs, funcs::MultiplyFunctor<T>(), 0);
delete equal_out; delete equal_out;
} }
} // namespace phi } // namespace phi
......
...@@ -48,7 +48,7 @@ void ReduceSumGradKernel(const Context& dev_ctx, ...@@ -48,7 +48,7 @@ void ReduceSumGradKernel(const Context& dev_ctx,
// call ReduceGrad // call ReduceGrad
dev_ctx.Alloc(x_grad, x.dtype()); dev_ctx.Alloc(x_grad, x.dtype());
using MPType = typename kps::details::MPTypeTrait<T>::Type; using MPType = typename kps::details::MPTypeTrait<T>::Type;
phi::ReduceGrad<T, kps::IdentityFunctor<T, MPType>>( phi::ReduceGrad<kps::IdentityFunctor<T, MPType>>(
dev_ctx, dev_ctx,
&new_out_grad, &new_out_grad,
x_grad, x_grad,
......
...@@ -46,8 +46,7 @@ void SquaredL2NormGradKernel(const Context& dev_ctx, ...@@ -46,8 +46,7 @@ void SquaredL2NormGradKernel(const Context& dev_ctx,
std::vector<const DenseTensor*> ins{&x, &dout}; std::vector<const DenseTensor*> ins{&x, &dout};
std::vector<DenseTensor*> outs{dx}; std::vector<DenseTensor*> outs{dx};
funcs::BroadcastKernel<ElementwiseType::kBinary, T, T>( funcs::BroadcastKernel<T>(dev_ctx, ins, &outs, phi::DoubleMulFunctor<T>());
dev_ctx, ins, &outs, -1, phi::DoubleMulFunctor<T>());
} }
} // namespace phi } // namespace phi
......
...@@ -78,8 +78,8 @@ void TileKernel(const Context& dev_ctx, ...@@ -78,8 +78,8 @@ void TileKernel(const Context& dev_ctx,
tmp_out.Resize(make_ddim(vec_x_dims)); tmp_out.Resize(make_ddim(vec_x_dims));
dev_ctx.template Alloc<T>(&tmp_out); dev_ctx.template Alloc<T>(&tmp_out);
std::vector<DenseTensor*> outs = {&tmp_out}; std::vector<DenseTensor*> outs = {&tmp_out};
phi::funcs::BroadcastKernel<ElementwiseType::kUnary, T, T>( phi::funcs::BroadcastKernel<T>(
dev_ctx, ins, &outs, i, kps::IdentityFunctor<T>()); dev_ctx, ins, &outs, kps::IdentityFunctor<T>(), i);
tmp_out.Resize(out_dims); tmp_out.Resize(out_dims);
new_x = tmp_out; new_x = tmp_out;
} }
...@@ -89,8 +89,8 @@ void TileKernel(const Context& dev_ctx, ...@@ -89,8 +89,8 @@ void TileKernel(const Context& dev_ctx,
out->Resize(make_ddim(vec_x_dims)); out->Resize(make_ddim(vec_x_dims));
dev_ctx.template Alloc<T>(out); dev_ctx.template Alloc<T>(out);
std::vector<DenseTensor*> outs = {out}; std::vector<DenseTensor*> outs = {out};
phi::funcs::BroadcastKernel<ElementwiseType::kUnary, T, T>( phi::funcs::BroadcastKernel<T>(
dev_ctx, ins, &outs, i, kps::IdentityFunctor<T>()); dev_ctx, ins, &outs, kps::IdentityFunctor<T>(), i);
out->Resize(out_dims); out->Resize(out_dims);
} }
} }
......
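Side note (sketch, not in the commit): TileKernel is one of the call sites where the trailing axis argument still carries information, because the identity copy is broadcast once per tiled dimension. One such step under the new signature, assuming ins currently holds the single source tensor:

// Unary broadcast along dimension i into the enlarged shape; the unary arity
// is implied by ins containing one tensor, so no kUnary tag is needed.
std::vector<DenseTensor*> outs = {&tmp_out};
phi::funcs::BroadcastKernel<T>(
    dev_ctx, ins, &outs, kps::IdentityFunctor<T>(), /*axis=*/i);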
...@@ -91,8 +91,7 @@ struct BinaryOperation { ...@@ -91,8 +91,7 @@ struct BinaryOperation {
DenseTensor* output) { DenseTensor* output) {
std::vector<const DenseTensor*> ins{&lhs, &rhs}; std::vector<const DenseTensor*> ins{&lhs, &rhs};
std::vector<DenseTensor*> outs{output}; std::vector<DenseTensor*> outs{output};
phi::funcs::BroadcastKernel<phi::ElementwiseType::kBinary, T, T>( phi::funcs::BroadcastKernel<T>(dev_ctx, ins, &outs, BinaryFunctor<T>(), 0);
dev_ctx, ins, &outs, 0, BinaryFunctor<T>());
} }
}; };
......
...@@ -90,16 +90,16 @@ void ComplexKernel(const Context& dev_ctx, ...@@ -90,16 +90,16 @@ void ComplexKernel(const Context& dev_ctx,
// facility functions // facility functions
#if defined(__NVCC__) || defined(__HIPCC__) #if defined(__NVCC__) || defined(__HIPCC__)
phi::funcs::ElementwiseCompute<RealAndImagToComplexFunctor<T>, T, C>( phi::funcs::ElementwiseCompute<RealAndImagToComplexFunctor<T>, T, C>(
dev_ctx, x, y, /*axis*/ -1, RealAndImagToComplexFunctor<T>(), out); dev_ctx, x, y, RealAndImagToComplexFunctor<T>(), out);
#else #else
auto x_dims = x.dims(); auto x_dims = x.dims();
auto y_dims = y.dims(); auto y_dims = y.dims();
if (x_dims.size() >= y_dims.size()) { if (x_dims.size() >= y_dims.size()) {
phi::funcs::ElementwiseCompute<RealAndImagToComplexFunctor<T>, T, C>( phi::funcs::ElementwiseCompute<RealAndImagToComplexFunctor<T>, T, C>(
dev_ctx, x, y, /*axis*/ -1, RealAndImagToComplexFunctor<T>(), out); dev_ctx, x, y, RealAndImagToComplexFunctor<T>(), out);
} else { } else {
phi::funcs::ElementwiseCompute<ImagAndRealToComplexFunctor<T>, T, C>( phi::funcs::ElementwiseCompute<ImagAndRealToComplexFunctor<T>, T, C>(
dev_ctx, x, y, /*axis*/ -1, ImagAndRealToComplexFunctor<T>(), out); dev_ctx, x, y, ImagAndRealToComplexFunctor<T>(), out);
} }
#endif #endif
} }
......
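For orientation (illustrative, not from the patch): ElementwiseCompute keeps its <Functor, T, OutType> template list, with OutType apparently defaulting to T judging by the FMaxKernel/FMinKernel hunks below, but the functor now precedes the output tensor and axis becomes a trailing argument with a default. Before/after with a hypothetical functor Fn<T>:

// Before: axis sat in the middle of the argument list and had to be spelled out.
//   phi::funcs::ElementwiseCompute<Fn<T>, T, OutT>(
//       dev_ctx, x, y, /*axis=*/-1, Fn<T>(), &out);
// After: axis trails and can be dropped in the common whole-tensor case.
phi::funcs::ElementwiseCompute<Fn<T>, T, OutT>(dev_ctx, x, y, Fn<T>(), &out);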
...@@ -76,15 +76,15 @@ void AddDoubleGradImpl(const Context& dev_ctx, ...@@ -76,15 +76,15 @@ void AddDoubleGradImpl(const Context& dev_ctx,
auto ddy_dims = ddy_safe.dims(); auto ddy_dims = ddy_safe.dims();
if (ddx_dims.size() >= ddy_dims.size()) { if (ddx_dims.size() >= ddy_dims.size()) {
funcs::ElementwiseCompute<funcs::AddFunctor<T>, T>( funcs::ElementwiseCompute<funcs::AddFunctor<T>, T>(
dev_ctx, ddx_safe, ddy_safe, axis, funcs::AddFunctor<T>(), ddout); dev_ctx, ddx_safe, ddy_safe, funcs::AddFunctor<T>(), ddout, axis);
} else { } else {
funcs::ElementwiseCompute<funcs::InverseAddFunctor<T>, T>( funcs::ElementwiseCompute<funcs::InverseAddFunctor<T>, T>(
dev_ctx, dev_ctx,
ddx_safe, ddx_safe,
ddy_safe, ddy_safe,
axis,
funcs::InverseAddFunctor<T>(), funcs::InverseAddFunctor<T>(),
ddout); ddout,
axis);
} }
} }
} }
...@@ -107,7 +107,7 @@ void SubtractDoubleGradImpl(const Context& dev_ctx, ...@@ -107,7 +107,7 @@ void SubtractDoubleGradImpl(const Context& dev_ctx,
dev_ctx.template Alloc<T>(ddout); dev_ctx.template Alloc<T>(ddout);
funcs::ElementwiseCompute<funcs::SubtractFunctor<T>, T>( funcs::ElementwiseCompute<funcs::SubtractFunctor<T>, T>(
dev_ctx, ddx_safe, ddy_safe, axis, funcs::SubtractFunctor<T>(), ddout); dev_ctx, ddx_safe, ddy_safe, funcs::SubtractFunctor<T>(), ddout, axis);
} }
} }
......
...@@ -39,10 +39,10 @@ namespace phi { ...@@ -39,10 +39,10 @@ namespace phi {
auto y_dims = y.dims(); \ auto y_dims = y.dims(); \
if (x_dims.size() >= y_dims.size()) { \ if (x_dims.size() >= y_dims.size()) { \
funcs::ElementwiseCompute<funcs::name##Functor<T>, T>( \ funcs::ElementwiseCompute<funcs::name##Functor<T>, T>( \
dev_ctx, x, y, axis, funcs::name##Functor<T>(), out); \ dev_ctx, x, y, funcs::name##Functor<T>(), out, axis); \
} else { \ } else { \
funcs::ElementwiseCompute<funcs::Inverse##name##Functor<T>, T>( \ funcs::ElementwiseCompute<funcs::Inverse##name##Functor<T>, T>( \
dev_ctx, x, y, axis, funcs::Inverse##name##Functor<T>(), out); \ dev_ctx, x, y, funcs::Inverse##name##Functor<T>(), out, axis); \
} \ } \
} \ } \
} }
...@@ -62,8 +62,8 @@ namespace phi { ...@@ -62,8 +62,8 @@ namespace phi {
inputs.emplace_back(&y); \ inputs.emplace_back(&y); \
outputs.emplace_back(out); \ outputs.emplace_back(out); \
dev_ctx.template Alloc<T>(out); \ dev_ctx.template Alloc<T>(out); \
funcs::BroadcastKernel<ElementwiseType::kBinary, T, T>( \ funcs::BroadcastKernel<T>( \
dev_ctx, inputs, &outputs, axis, funcs::name##Functor<T>()); \ dev_ctx, inputs, &outputs, funcs::name##Functor<T>(), axis); \
} }
template <typename T, typename Context> template <typename T, typename Context>
...@@ -72,8 +72,8 @@ void FMaxKernel(const Context& dev_ctx, ...@@ -72,8 +72,8 @@ void FMaxKernel(const Context& dev_ctx,
const DenseTensor& y, const DenseTensor& y,
DenseTensor* out) { DenseTensor* out) {
dev_ctx.template Alloc<T>(out); dev_ctx.template Alloc<T>(out);
funcs::ElementwiseCompute<funcs::FMaxFunctor<T>, T, T>( funcs::ElementwiseCompute<funcs::FMaxFunctor<T>, T>(
dev_ctx, x, y, -1, funcs::FMaxFunctor<T>(), out); dev_ctx, x, y, funcs::FMaxFunctor<T>(), out);
} }
template <typename T, typename Context> template <typename T, typename Context>
...@@ -82,8 +82,8 @@ void FMinKernel(const Context& dev_ctx, ...@@ -82,8 +82,8 @@ void FMinKernel(const Context& dev_ctx,
const DenseTensor& y, const DenseTensor& y,
DenseTensor* out) { DenseTensor* out) {
dev_ctx.template Alloc<T>(out); dev_ctx.template Alloc<T>(out);
funcs::ElementwiseCompute<funcs::FMinFunctor<T>, T, T>( funcs::ElementwiseCompute<funcs::FMinFunctor<T>, T>(
dev_ctx, x, y, -1, funcs::FMinFunctor<T>(), out); dev_ctx, x, y, funcs::FMinFunctor<T>(), out);
} }
} // namespace phi } // namespace phi
...@@ -153,12 +153,8 @@ void SetValueCompute(const Context& dev_ctx, ...@@ -153,12 +153,8 @@ void SetValueCompute(const Context& dev_ctx,
slice_tensor.Resize(slice_dims_for_assign); slice_tensor.Resize(slice_dims_for_assign);
if (value_tensor != nullptr) { if (value_tensor != nullptr) {
CheckIsDimsMatch(slice_dims_for_assign, value_tensor->dims()); CheckIsDimsMatch(slice_dims_for_assign, value_tensor->dims());
phi::funcs::ElementwiseCompute<SubFunctor<T>, T, T>(dev_ctx, phi::funcs::ElementwiseCompute<SubFunctor<T>, T>(
slice_tensor, dev_ctx, slice_tensor, *value_tensor, SubFunctor<T>(), &slice_tensor);
*value_tensor,
-1,
SubFunctor<T>(),
&slice_tensor);
} else { } else {
DenseTensor value_t(dtype); DenseTensor value_t(dtype);
auto value_dims = phi::make_ddim(shape); auto value_dims = phi::make_ddim(shape);
...@@ -166,8 +162,8 @@ void SetValueCompute(const Context& dev_ctx, ...@@ -166,8 +162,8 @@ void SetValueCompute(const Context& dev_ctx,
value_t.Resize(value_dims); value_t.Resize(value_dims);
dev_ctx.template Alloc<T>(&value_t); dev_ctx.template Alloc<T>(&value_t);
phi::funcs::ElementwiseCompute<SubFunctor<T>, T, T>( phi::funcs::ElementwiseCompute<SubFunctor<T>, T>(
dev_ctx, slice_tensor, value_t, -1, SubFunctor<T>(), &slice_tensor); dev_ctx, slice_tensor, value_t, SubFunctor<T>(), &slice_tensor);
} }
slice_tensor.Resize(slice_dims); slice_tensor.Resize(slice_dims);
......
...@@ -204,7 +204,6 @@ void SetValueImpl(const Context& dev_ctx, ...@@ -204,7 +204,6 @@ void SetValueImpl(const Context& dev_ctx,
dev_ctx, dev_ctx,
slice_tensor, slice_tensor,
value, value,
-1,
funcs::SubtractFunctor<T>(), funcs::SubtractFunctor<T>(),
&slice_tensor); &slice_tensor);
} else { } else {
...@@ -212,7 +211,6 @@ void SetValueImpl(const Context& dev_ctx, ...@@ -212,7 +211,6 @@ void SetValueImpl(const Context& dev_ctx,
dev_ctx, dev_ctx,
slice_tensor, slice_tensor,
value, value,
-1,
funcs::InverseSubtractFunctor<T>(), funcs::InverseSubtractFunctor<T>(),
&slice_tensor); &slice_tensor);
} }
......
...@@ -25,18 +25,17 @@ limitations under the License. */ ...@@ -25,18 +25,17 @@ limitations under the License. */
#include "paddle/phi/kernels/funcs/broadcast_function.h" #include "paddle/phi/kernels/funcs/broadcast_function.h"
namespace phi { namespace phi {
#define DEFINE_BITWISE_KERNEL(op_type) \ #define DEFINE_BITWISE_KERNEL(op_type) \
template <typename T, typename Context> \ template <typename T, typename Context> \
void Bitwise##op_type##Kernel(const Context& dev_ctx, \ void Bitwise##op_type##Kernel(const Context& dev_ctx, \
const DenseTensor& x, \ const DenseTensor& x, \
const DenseTensor& y, \ const DenseTensor& y, \
DenseTensor* out) { \ DenseTensor* out) { \
dev_ctx.template Alloc<T>(out); \ dev_ctx.template Alloc<T>(out); \
funcs::Bitwise##op_type##Functor<T> func; \ funcs::Bitwise##op_type##Functor<T> func; \
std::vector<const DenseTensor*> ins = {&x, &y}; \ std::vector<const DenseTensor*> ins = {&x, &y}; \
std::vector<DenseTensor*> outs = {out}; \ std::vector<DenseTensor*> outs = {out}; \
funcs::BroadcastKernel<ElementwiseType::kBinary, T, T>( \ funcs::BroadcastKernel<T>(dev_ctx, ins, &outs, func); \
dev_ctx, ins, &outs, -1, func); \
} }
DEFINE_BITWISE_KERNEL(And) DEFINE_BITWISE_KERNEL(And)
......
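Expansion sketch (not part of the diff, whitespace reflowed): what DEFINE_BITWISE_KERNEL(And) now generates. Bitwise ops read and write the same dtype, so a single T suffices and axis is left at its default:

template <typename T, typename Context>
void BitwiseAndKernel(const Context& dev_ctx,
                      const DenseTensor& x,
                      const DenseTensor& y,
                      DenseTensor* out) {
  dev_ctx.template Alloc<T>(out);
  funcs::BitwiseAndFunctor<T> func;
  std::vector<const DenseTensor*> ins = {&x, &y};
  std::vector<DenseTensor*> outs = {out};
  funcs::BroadcastKernel<T>(dev_ctx, ins, &outs, func);
}

Each op the macro is instantiated with produces the same shape of kernel.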
...@@ -55,8 +55,7 @@ inline void CompareKernelImpl(const Context& ctx, ...@@ -55,8 +55,7 @@ inline void CompareKernelImpl(const Context& ctx,
ctx.template Alloc<bool>(out); ctx.template Alloc<bool>(out);
std::vector<const DenseTensor*> ins{&x, &y}; std::vector<const DenseTensor*> ins{&x, &y};
std::vector<DenseTensor*> outs{out}; std::vector<DenseTensor*> outs{out};
funcs::BroadcastKernel<ElementwiseType::kBinary, T, bool>( funcs::BroadcastKernel<bool>(ctx, ins, &outs, Functor(), axis);
ctx, ins, &outs, axis, Functor());
} }
#ifndef PADDLE_WITH_XPU_KP #ifndef PADDLE_WITH_XPU_KP
......
...@@ -72,8 +72,8 @@ void HeavisideKernel(const Context& dev_ctx, ...@@ -72,8 +72,8 @@ void HeavisideKernel(const Context& dev_ctx,
inputs.emplace_back(&y); inputs.emplace_back(&y);
outputs.emplace_back(out); outputs.emplace_back(out);
dev_ctx.template Alloc<T>(out); dev_ctx.template Alloc<T>(out);
funcs::BroadcastKernel<ElementwiseType::kBinary, T, T>( funcs::BroadcastKernel<T>(
dev_ctx, inputs, &outputs, -1, funcs::ElementwiseHeavisideFunctor<T>()); dev_ctx, inputs, &outputs, funcs::ElementwiseHeavisideFunctor<T>());
} }
template <typename T, typename Context> template <typename T, typename Context>
......
...@@ -25,20 +25,17 @@ ...@@ -25,20 +25,17 @@
namespace phi { namespace phi {
#define DEFINE_LOGICAL_BINARY_KERNEL(type) \ #define DEFINE_LOGICAL_BINARY_KERNEL(type) \
template <typename T, typename Context> \ template <typename T, typename Context> \
void Logical##type##Kernel(const Context& dev_ctx, \ void Logical##type##Kernel(const Context& dev_ctx, \
const DenseTensor& x, \ const DenseTensor& x, \
const DenseTensor& y, \ const DenseTensor& y, \
DenseTensor* out) { \ DenseTensor* out) { \
using InT = typename funcs::Logical##type##Functor<T>::ELEMENT_TYPE; \ dev_ctx.template Alloc<bool>(out); \
using OutT = bool; \ funcs::Logical##type##Functor<T> binary_func; \
dev_ctx.template Alloc<bool>(out); \ std::vector<const DenseTensor*> ins = {&x, &y}; \
funcs::Logical##type##Functor<T> binary_func; \ std::vector<DenseTensor*> outs = {out}; \
std::vector<const DenseTensor*> ins = {&x, &y}; \ funcs::BroadcastKernel<bool>(dev_ctx, ins, &outs, binary_func); \
std::vector<DenseTensor*> outs = {out}; \
funcs::BroadcastKernel<ElementwiseType::kBinary, InT, OutT>( \
dev_ctx, ins, &outs, -1, binary_func); \
} }
DEFINE_LOGICAL_BINARY_KERNEL(And) DEFINE_LOGICAL_BINARY_KERNEL(And)
...@@ -50,15 +47,11 @@ template <typename T, typename Context> ...@@ -50,15 +47,11 @@ template <typename T, typename Context>
void LogicalNotKernel(const Context& dev_ctx, void LogicalNotKernel(const Context& dev_ctx,
const DenseTensor& x, const DenseTensor& x,
DenseTensor* out) { DenseTensor* out) {
using InT = typename funcs::LogicalNotFunctor<T>::ELEMENT_TYPE;
using OutT = bool;
dev_ctx.template Alloc<bool>(out); dev_ctx.template Alloc<bool>(out);
funcs::LogicalNotFunctor<T> unary_func; funcs::LogicalNotFunctor<T> unary_func;
std::vector<const DenseTensor*> ins = {&x}; std::vector<const DenseTensor*> ins = {&x};
std::vector<DenseTensor*> outs = {out}; std::vector<DenseTensor*> outs = {out};
funcs::BroadcastKernel<ElementwiseType::kUnary, InT, OutT>( funcs::BroadcastKernel<bool>(dev_ctx, ins, &outs, unary_func);
dev_ctx, ins, &outs, -1, unary_func);
} }
} // namespace phi } // namespace phi
......
...@@ -30,7 +30,7 @@ void MaximumRawKernel(const Context& dev_ctx, ...@@ -30,7 +30,7 @@ void MaximumRawKernel(const Context& dev_ctx,
// allocate memory for out // allocate memory for out
dev_ctx.template Alloc<T>(out); dev_ctx.template Alloc<T>(out);
funcs::ElementwiseCompute<funcs::MaximumFunctor<T>, T>( funcs::ElementwiseCompute<funcs::MaximumFunctor<T>, T>(
dev_ctx, x, y, axis, funcs::MaximumFunctor<T>(), out); dev_ctx, x, y, funcs::MaximumFunctor<T>(), out, axis);
} }
template <typename T, typename Context> template <typename T, typename Context>
...@@ -42,7 +42,7 @@ void MinimumRawKernel(const Context& dev_ctx, ...@@ -42,7 +42,7 @@ void MinimumRawKernel(const Context& dev_ctx,
// allocate memory for out // allocate memory for out
dev_ctx.template Alloc<T>(out); dev_ctx.template Alloc<T>(out);
funcs::ElementwiseCompute<funcs::MinimumFunctor<T>, T>( funcs::ElementwiseCompute<funcs::MinimumFunctor<T>, T>(
dev_ctx, x, y, axis, funcs::MinimumFunctor<T>(), out); dev_ctx, x, y, funcs::MinimumFunctor<T>(), out, axis);
} }
template <typename T, typename Context> template <typename T, typename Context>
...@@ -57,10 +57,10 @@ void RemainderRawKernel(const Context& dev_ctx, ...@@ -57,10 +57,10 @@ void RemainderRawKernel(const Context& dev_ctx,
auto y_dims = y.dims(); auto y_dims = y.dims();
if (x_dims.size() >= y_dims.size()) { if (x_dims.size() >= y_dims.size()) {
funcs::ElementwiseCompute<funcs::RemainderFunctor<T>, T>( funcs::ElementwiseCompute<funcs::RemainderFunctor<T>, T>(
dev_ctx, x, y, axis, funcs::RemainderFunctor<T>(), out); dev_ctx, x, y, funcs::RemainderFunctor<T>(), out, axis);
} else { } else {
funcs::ElementwiseCompute<funcs::InverseRemainderFunctor<T>, T>( funcs::ElementwiseCompute<funcs::InverseRemainderFunctor<T>, T>(
dev_ctx, x, y, axis, funcs::InverseRemainderFunctor<T>(), out); dev_ctx, x, y, funcs::InverseRemainderFunctor<T>(), out, axis);
} }
} }
...@@ -76,10 +76,10 @@ void FloorDivideRawKernel(const Context& dev_ctx, ...@@ -76,10 +76,10 @@ void FloorDivideRawKernel(const Context& dev_ctx,
auto y_dims = y.dims(); auto y_dims = y.dims();
if (x_dims.size() >= y_dims.size()) { if (x_dims.size() >= y_dims.size()) {
funcs::ElementwiseCompute<funcs::FloorDivideFunctor<T>, T>( funcs::ElementwiseCompute<funcs::FloorDivideFunctor<T>, T>(
dev_ctx, x, y, axis, funcs::FloorDivideFunctor<T>(), out); dev_ctx, x, y, funcs::FloorDivideFunctor<T>(), out, axis);
} else { } else {
funcs::ElementwiseCompute<funcs::InverseFloorDivideFunctor<T>, T>( funcs::ElementwiseCompute<funcs::InverseFloorDivideFunctor<T>, T>(
dev_ctx, x, y, axis, funcs::InverseFloorDivideFunctor<T>(), out); dev_ctx, x, y, funcs::InverseFloorDivideFunctor<T>(), out, axis);
} }
} }
...@@ -95,10 +95,10 @@ void ElementwisePowRawKernel(const Context& dev_ctx, ...@@ -95,10 +95,10 @@ void ElementwisePowRawKernel(const Context& dev_ctx,
auto y_dims = y.dims(); auto y_dims = y.dims();
if (x_dims.size() >= y_dims.size()) { if (x_dims.size() >= y_dims.size()) {
funcs::ElementwiseCompute<funcs::ElementwisePowFunctor<T>, T>( funcs::ElementwiseCompute<funcs::ElementwisePowFunctor<T>, T>(
dev_ctx, x, y, axis, funcs::ElementwisePowFunctor<T>(), out); dev_ctx, x, y, funcs::ElementwisePowFunctor<T>(), out, axis);
} else { } else {
funcs::ElementwiseCompute<funcs::ElementwiseInversePowFunctor<T>, T>( funcs::ElementwiseCompute<funcs::ElementwiseInversePowFunctor<T>, T>(
dev_ctx, x, y, axis, funcs::ElementwiseInversePowFunctor<T>(), out); dev_ctx, x, y, funcs::ElementwiseInversePowFunctor<T>(), out, axis);
} }
} }
......
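Aside (sketch with hypothetical shapes, not from the patch): the CPU path keeps separate forward and Inverse functor branches because ElementwiseCompute appears to want the higher-rank operand driving the broadcast; when y outranks x, the Inverse variant compensates so the op's operand order is preserved even though the call still passes x and y unchanged:

// x: [2, 3, 4], y: [3, 4]  -> forward functor, plain x % y.
funcs::ElementwiseCompute<funcs::RemainderFunctor<T>, T>(
    dev_ctx, x, y, funcs::RemainderFunctor<T>(), out, axis);
// x: [3, 4], y: [2, 3, 4]  -> Inverse variant keeps the result equal to x % y
// (assumption based on the functor naming).
funcs::ElementwiseCompute<funcs::InverseRemainderFunctor<T>, T>(
    dev_ctx, x, y, funcs::InverseRemainderFunctor<T>(), out, axis);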
...@@ -36,8 +36,8 @@ void MaximumRawKernel(const Context& dev_ctx, ...@@ -36,8 +36,8 @@ void MaximumRawKernel(const Context& dev_ctx,
inputs.emplace_back(&y); inputs.emplace_back(&y);
outputs.emplace_back(out); outputs.emplace_back(out);
dev_ctx.template Alloc<T>(out); dev_ctx.template Alloc<T>(out);
funcs::BroadcastKernel<ElementwiseType::kBinary, T, T>( funcs::BroadcastKernel<T>(
dev_ctx, inputs, &outputs, axis, funcs::MaximumFunctor<T>()); dev_ctx, inputs, &outputs, funcs::MaximumFunctor<T>(), axis);
} }
template <typename T, typename Context> template <typename T, typename Context>
...@@ -54,8 +54,8 @@ void MinimumRawKernel(const Context& dev_ctx, ...@@ -54,8 +54,8 @@ void MinimumRawKernel(const Context& dev_ctx,
inputs.emplace_back(&y); inputs.emplace_back(&y);
outputs.emplace_back(out); outputs.emplace_back(out);
dev_ctx.template Alloc<T>(out); dev_ctx.template Alloc<T>(out);
funcs::BroadcastKernel<ElementwiseType::kBinary, T, T>( funcs::BroadcastKernel<T>(
dev_ctx, inputs, &outputs, axis, funcs::MinimumFunctor<T>()); dev_ctx, inputs, &outputs, funcs::MinimumFunctor<T>(), axis);
} }
template <typename T, typename Context> template <typename T, typename Context>
...@@ -72,8 +72,8 @@ void RemainderRawKernel(const Context& dev_ctx, ...@@ -72,8 +72,8 @@ void RemainderRawKernel(const Context& dev_ctx,
inputs.emplace_back(&y); inputs.emplace_back(&y);
outputs.emplace_back(out); outputs.emplace_back(out);
dev_ctx.template Alloc<T>(out); dev_ctx.template Alloc<T>(out);
funcs::BroadcastKernel<ElementwiseType::kBinary, T, T>( funcs::BroadcastKernel<T>(
dev_ctx, inputs, &outputs, axis, funcs::RemainderFunctor<T>()); dev_ctx, inputs, &outputs, funcs::RemainderFunctor<T>(), axis);
} }
template <typename T, typename Context> template <typename T, typename Context>
...@@ -90,8 +90,8 @@ void FloorDivideRawKernel(const Context& dev_ctx, ...@@ -90,8 +90,8 @@ void FloorDivideRawKernel(const Context& dev_ctx,
inputs.emplace_back(&y); inputs.emplace_back(&y);
outputs.emplace_back(out); outputs.emplace_back(out);
dev_ctx.template Alloc<T>(out); dev_ctx.template Alloc<T>(out);
funcs::BroadcastKernel<ElementwiseType::kBinary, T, T>( funcs::BroadcastKernel<T>(
dev_ctx, inputs, &outputs, axis, funcs::FloorDivideFunctor<T>()); dev_ctx, inputs, &outputs, funcs::FloorDivideFunctor<T>(), axis);
} }
template <typename T, typename Context> template <typename T, typename Context>
...@@ -108,8 +108,8 @@ void ElementwisePowRawKernel(const Context& dev_ctx, ...@@ -108,8 +108,8 @@ void ElementwisePowRawKernel(const Context& dev_ctx,
inputs.emplace_back(&y); inputs.emplace_back(&y);
outputs.emplace_back(out); outputs.emplace_back(out);
dev_ctx.template Alloc<T>(out); dev_ctx.template Alloc<T>(out);
funcs::BroadcastKernel<ElementwiseType::kBinary, T, T>( funcs::BroadcastKernel<T>(
dev_ctx, inputs, &outputs, axis, funcs::ElementwisePowFunctor<T>()); dev_ctx, inputs, &outputs, funcs::ElementwisePowFunctor<T>(), axis);
} }
} // namespace phi } // namespace phi
...@@ -174,4 +174,5 @@ PD_REGISTER_KERNEL(elementwise_pow_raw, ...@@ -174,4 +174,5 @@ PD_REGISTER_KERNEL(elementwise_pow_raw,
float16, float16,
int64_t, int64_t,
bfloat16) {} bfloat16) {}
#endif #endif
...@@ -89,8 +89,7 @@ void TestCase(const phi::GPUContext& dev_ctx, ...@@ -89,8 +89,7 @@ void TestCase(const phi::GPUContext& dev_ctx,
d_in1.get(), d_in2.get(), d_in3.get()}; d_in1.get(), d_in2.get(), d_in3.get()};
std::vector<phi::DenseTensor*> outputs{d_out.get()}; std::vector<phi::DenseTensor*> outputs{d_out.get()};
for (int i = 0; i < times; ++i) { for (int i = 0; i < times; ++i) {
phi::funcs::BroadcastKernel<phi::ElementwiseType::kTernary, T, T>( phi::funcs::BroadcastKernel<T>(dev_ctx, inputs, &outputs, compute);
dev_ctx, inputs, &outputs, -1, compute);
} }
dev_ctx.Wait(); dev_ctx.Wait();
} }
......
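Closing sketch (illustrative, not in the diff): the TestCase hunk above exercises the ternary case. With the kTernary tag removed, passing three inputs is what makes the launch ternary, so the functor bound to compute must take three arguments:

// Three inputs -> ternary broadcast; the arity must match the functor.
std::vector<const phi::DenseTensor*> inputs{d_in1.get(), d_in2.get(),
                                            d_in3.get()};
std::vector<phi::DenseTensor*> outputs{d_out.get()};
phi::funcs::BroadcastKernel<T>(dev_ctx, inputs, &outputs, compute);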