Unverified · Commit 52ef8656 authored by niuliling123, committed by GitHub

[cherry-pick]Delete ElementwiseKernel in BroadcastKernel (#42779) (#43210)

Delete ElementwiseKernel in BroadcastKernel
Remove the duplicated functionality invoked by every Broadcast call, reducing both compile time and binary size.
Parent 835a1888
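Before this change, `BroadcastKernel` carried a same-shape fast path that forwarded to `phi::funcs::ElementwiseKernel`, so every broadcast op instantiated both kernels. The sketch below is a minimal standalone analogue of why that path is redundant (plain `std::vector`s instead of phi tensors; every name in it is illustrative, not phi API): when all inputs already match the output shape, the broadcast loop degenerates into the elementwise loop.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <functional>
#include <vector>

// Plain elementwise path: one value per output element, no index remapping.
template <typename T, typename Functor>
std::vector<T> ElementwiseApply(const std::vector<T>& a,
                                const std::vector<T>& b,
                                Functor func) {
  std::vector<T> out(a.size());
  for (std::size_t i = 0; i < a.size(); ++i) out[i] = func(a[i], b[i]);
  return out;
}

// Broadcast path: a size-1 input is stretched across the output. With equal
// sizes the index mapping is the identity, i.e. exactly the loop above.
template <typename T, typename Functor>
std::vector<T> BroadcastApply(const std::vector<T>& a,
                              const std::vector<T>& b,
                              Functor func) {
  std::size_t n = std::max(a.size(), b.size());
  std::vector<T> out(n);
  for (std::size_t i = 0; i < n; ++i) {
    out[i] = func(a[a.size() == 1 ? 0 : i], b[b.size() == 1 ? 0 : i]);
  }
  return out;
}

int main() {
  std::vector<int> x = {1, 2, 3, 4}, y = {10, 20, 30, 40};
  // Same shape: both paths produce identical results, so a separate
  // elementwise fast path inside the broadcast entry point adds nothing.
  assert(ElementwiseApply(x, y, std::plus<int>()) ==
         BroadcastApply(x, y, std::plus<int>()));
  // Broadcasting a "scalar" still exercises the index remapping.
  std::vector<int> s = {100};
  assert((BroadcastApply(x, s, std::plus<int>()) ==
          std::vector<int>{101, 102, 103, 104}));
  return 0;
}
```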
@@ -496,26 +496,16 @@ void BroadcastKernel(const KPDevice &ctx,
                      Functor func) {
   std::vector<int> dims_size;
   dims_size.reserve(ins.size());
-  bool no_broadcast_flag = true;
   for (auto *in : ins) {
-    no_broadcast_flag &= ins[0]->dims() == in->dims();
     dims_size.emplace_back(in->dims().size());
   }
-  if (ins.size() > 0 && outs->size() > 0) {
-    no_broadcast_flag &= outs->at(0)->dims() == ins[0]->dims();
-  }
-
-  if (no_broadcast_flag) {
-    phi::funcs::ElementwiseKernel<OutT, Functor, NumOuts>(ctx, ins, outs, func);
-  } else {
-    axis = axis == -1
-               ? *std::max_element(dims_size.begin(), dims_size.end()) -
-                     *std::min_element(dims_size.begin(), dims_size.end())
-               : axis;
-    BroadcastKernelForDifferentVecSize<ET, InT, OutT, Functor, NumOuts>(
-        ctx, ins, outs, axis, func);
-  }
+  axis = axis == -1
+             ? *std::max_element(dims_size.begin(), dims_size.end()) -
+                   *std::min_element(dims_size.begin(), dims_size.end())
+             : axis;
+  BroadcastKernelForDifferentVecSize<ET, InT, OutT, Functor, NumOuts>(
+      ctx, ins, outs, axis, func);
 }
 
 template <typename Functor, typename T, typename OutType = T>
...
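The ternary kept in the new code infers the broadcast axis whenever the caller passes -1: it is the rank difference between the highest- and lowest-rank inputs, which aligns the low-rank operand with the trailing dimensions. A small standalone illustration (the function name is hypothetical, not the phi implementation):

```cpp
#include <algorithm>
#include <cassert>
#include <vector>

// Default-axis rule: with axis == -1, use max rank minus min rank.
int InferBroadcastAxis(const std::vector<int>& ranks, int axis) {
  return axis == -1
             ? *std::max_element(ranks.begin(), ranks.end()) -
                   *std::min_element(ranks.begin(), ranks.end())
             : axis;
}

int main() {
  // A rank-3 tensor broadcast against a rank-1 tensor: inferred axis is 2,
  // i.e. the low-rank operand aligns with the trailing dimension.
  assert(InferBroadcastAxis({3, 1}, -1) == 2);
  // An explicit axis is passed through untouched.
  assert(InferBroadcastAxis({3, 1}, 0) == 0);
  return 0;
}
```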
@@ -46,9 +46,9 @@ void BitwiseNotKernel(const Context& dev_ctx,
   dev_ctx.template Alloc<T>(out);
   std::vector<const DenseTensor*> ins = {&x};
   std::vector<DenseTensor*> outs = {out};
-  funcs::BitwiseNotFunctor<T> func;
-  funcs::BroadcastKernel<ElementwiseType::kUnary, T, T>(
-      dev_ctx, ins, &outs, -1, func);
+  funcs::BitwiseNotFunctor<T> unary_func;
+  funcs::ElementwiseKernel<T, funcs::BitwiseNotFunctor<T>>(
+      dev_ctx, ins, &outs, unary_func);
 }
 
 }  // namespace phi
...
@@ -81,11 +81,13 @@ void GeluGradKernel(const Context& dev_ctx,
     }
   }
 #endif
-    phi::funcs::BroadcastKernel<ElementwiseType::kBinary, T, T>(
-        dev_ctx, ins, &outs, 0, GeluWithApproximateGradFunctor<T>());
+    using Functor = GeluWithApproximateGradFunctor<T>;
+    phi::funcs::ElementwiseKernel<T, Functor, 1>(
+        dev_ctx, ins, &outs, Functor());
   } else {
-    phi::funcs::BroadcastKernel<ElementwiseType::kBinary, T, T>(
-        dev_ctx, ins, &outs, 0, GeluWithoutApproximateGradFunctor<T>());
+    using Functor = GeluWithoutApproximateGradFunctor<T>;
+    phi::funcs::ElementwiseKernel<T, Functor, 1>(
+        dev_ctx, ins, &outs, Functor());
   }
 }
...
@@ -71,11 +71,13 @@ void GeluKernel(const Context& dev_ctx,
     }
   }
 #endif
-    phi::funcs::BroadcastKernel<ElementwiseType::kBinary, T, T>(
-        dev_ctx, ins, &outs, 0, GeluWithApproximateFunctor<T>());
+    using Functor = GeluWithApproximateFunctor<T>;
+    phi::funcs::ElementwiseKernel<T, Functor, 1>(
+        dev_ctx, ins, &outs, Functor());
   } else {
-    phi::funcs::BroadcastKernel<ElementwiseType::kBinary, T, T>(
-        dev_ctx, ins, &outs, 0, GeluWithoutApproximateFunctor<T>());
+    using Functor = GeluWithoutApproximateFunctor<T>;
+    phi::funcs::ElementwiseKernel<T, Functor, 1>(
+        dev_ctx, ins, &outs, Functor());
   }
 }
...
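Both GELU call sites now name the functor once with a `using Functor = ...;` alias and invoke `phi::funcs::ElementwiseKernel<T, Functor, 1>` with a default-constructed instance. The toy below mirrors that `<T, Functor, NumOuts>` call shape outside of phi; `TanhApproxGelu` and `ElementwiseKernelToy` are made-up stand-ins, and the formula is the standard tanh GELU approximation rather than Paddle's exact functor.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Hypothetical GELU functor using the standard tanh approximation.
template <typename T>
struct TanhApproxGelu {
  T operator()(T x) const {
    const T kAlpha = static_cast<T>(0.7978845608028654);  // sqrt(2 / pi)
    return static_cast<T>(0.5) * x *
           (static_cast<T>(1) +
            std::tanh(kAlpha * (x + static_cast<T>(0.044715) * x * x * x)));
  }
};

// Toy driver mimicking the <T, Functor, NumOuts> template shape of the new call sites.
template <typename T, typename Functor, int NumOuts = 1>
void ElementwiseKernelToy(const std::vector<T>& in,
                          std::vector<T>* out,
                          Functor func) {
  static_assert(NumOuts == 1, "toy driver handles a single output");
  out->resize(in.size());
  for (std::size_t i = 0; i < in.size(); ++i) (*out)[i] = func(in[i]);
}

int main() {
  using Functor = TanhApproxGelu<float>;  // name the functor once, as the diff does
  std::vector<float> x = {-1.0f, 0.0f, 1.0f};
  std::vector<float> y;
  ElementwiseKernelToy<float, Functor, 1>(x, &y, Functor());
  assert(y[1] == 0.0f);  // GELU(0) == 0
  return 0;
}
```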
@@ -43,22 +43,19 @@ void ReduceGrad(const GPUContext& dev_ctx,
   }));
 }
 
-template <typename T,
-          typename Context,
-          template <typename, typename> class TransformOp>
+template <typename T, typename OutT, typename Context, typename Functor>
 void ReduceGradKernel(const Context& dev_ctx,
                       const DenseTensor& x,
                       const DenseTensor& out_grad,
                       const std::vector<int64_t>& dims,
                       bool keep_dim,
                       bool reduce_all,
-                      DenseTensor* x_grad) {
+                      DenseTensor* x_grad,
+                      Functor functor) {
   auto* in_x = &x;
   auto* d_out = &out_grad;
   auto* d_x = x_grad;
-  auto pt_out_dtype = x.dtype();
 
   // get reduce_dim and reduce_num for reduce_mean_grad
   int dim_size = in_x->dims().size();
   std::vector<int> reduce_dims =
@@ -79,14 +76,10 @@ void ReduceGradKernel(const Context& dev_ctx,
   auto pt_d_out = new_d_out;
   auto pt_d_x = *d_x;
-  using MPType = typename kps::details::MPTypeTrait<T>::Type;
-
-  phi::ReduceGrad<T, TransformOp<T, MPType>>(
-      dev_ctx,
-      &pt_d_out,
-      &pt_d_x,
-      pt_out_dtype,
-      TransformOp<T, MPType>(reduce_num));
+  std::vector<const DenseTensor*> inputs = {&pt_d_out};
+  std::vector<DenseTensor*> outputs = {&pt_d_x};
+  funcs::BroadcastKernel<phi::ElementwiseType::kUnary, T, OutT>(
+      dev_ctx, inputs, &outputs, 0, functor);
 }
 
 }  // namespace phi
...
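The refactor replaces the `template <typename, typename> class TransformOp` template-template parameter with an ordinary `Functor functor` argument, so `ReduceGradKernel` no longer needs to know `reduce_num` or the mixed-precision type; the caller builds the functor and injects it. A minimal sketch of the same pattern under illustrative names (not phi code):

```cpp
#include <cassert>
#include <vector>

// Stand-in for a transform such as kps::DivideFunctor, constructed by the
// caller with reduce_num already baked in.
template <typename T>
struct Divide {
  explicit Divide(int n) : n_(static_cast<T>(n)) {}
  T operator()(T g) const { return g / n_; }
  T n_;
};

// After the refactor: the grad helper is generic over any functor and knows
// nothing about reduce_num or mixed-precision types.
template <typename T, typename Functor>
std::vector<T> GradBroadcastToy(const std::vector<T>& d_out, int n, Functor functor) {
  std::vector<T> d_x;
  for (T g : d_out) {
    for (int i = 0; i < n; ++i) d_x.push_back(functor(g));  // broadcast, then transform
  }
  return d_x;
}

int main() {
  // Mean over 4 elements: each input position receives d_out / 4.
  std::vector<float> d_out = {8.0f};
  std::vector<float> d_x = GradBroadcastToy<float>(d_out, 4, Divide<float>(4));
  assert(d_x.size() == 4 && d_x[0] == 2.0f);
  return 0;
}
```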
@@ -29,8 +29,23 @@ void ReduceMeanGradKernel(const Context& dev_ctx,
                           bool keep_dim,
                           bool reduce_all,
                           DenseTensor* x_grad) {
-  ReduceGradKernel<T, Context, kps::DivideFunctor>(
-      dev_ctx, x, out_grad, dims, keep_dim, reduce_all, x_grad);
+  int dim_size = x.dims().size();
+  std::vector<int> reduce_dims =
+      funcs::details::GetReduceDim(dims, dim_size, reduce_all);
+  int reduce_num = 1;
+  for (auto i : reduce_dims) {
+    reduce_num *= (x.dims())[i];
+  }
+  using MPType = typename kps::details::MPTypeTrait<T>::Type;
+  ReduceGradKernel<T, T, Context, kps::DivideFunctor<T, MPType>>(
+      dev_ctx,
+      x,
+      out_grad,
+      dims,
+      keep_dim,
+      reduce_all,
+      x_grad,
+      kps::DivideFunctor<T, MPType>(reduce_num));
 }
 
 }  // namespace phi
...
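`ReduceMeanGradKernel` now computes `reduce_num` itself (the product of the reduced dimension sizes) and bakes it into `kps::DivideFunctor`, since the generic `ReduceGradKernel` no longer does. A standalone worked example of that computation (names are illustrative):

```cpp
#include <cassert>
#include <vector>

// Product of the sizes of the reduced dimensions.
int ComputeReduceNum(const std::vector<int>& shape, const std::vector<int>& reduce_dims) {
  int reduce_num = 1;
  for (int d : reduce_dims) reduce_num *= shape[d];
  return reduce_num;
}

int main() {
  // x has shape [2, 3, 4]; reducing dims 1 and 2 averages over 3 * 4 = 12
  // values, so mean_grad divides the upstream gradient by 12.
  assert(ComputeReduceNum({2, 3, 4}, {1, 2}) == 12);
  // reduce_all: every dimension is reduced.
  assert(ComputeReduceNum({2, 3, 4}, {0, 1, 2}) == 24);
  return 0;
}
```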
@@ -29,8 +29,40 @@ void ReduceSumGradKernel(const Context& dev_ctx,
                          bool keep_dim,
                          bool reduce_all,
                          DenseTensor* x_grad) {
-  ReduceGradKernel<T, Context, kps::IdentityFunctor>(
-      dev_ctx, x, out_grad, dims, keep_dim, reduce_all, x_grad);
+  using MPType = typename kps::details::MPTypeTrait<T>::Type;
+  auto out_dtype = x.dtype();
+  auto* in_x = &x;
+  auto* d_out = &out_grad;
+  auto* d_x = x_grad;
+
+  // get reduce_dim and reduce_num for reduce_mean_grad
+  int dim_size = in_x->dims().size();
+  std::vector<int> reduce_dims =
+      funcs::details::GetReduceDim(dims, dim_size, reduce_all);
+  auto update_dims = vectorize(d_x->dims());
+  int reduce_num = 1;
+  for (auto i : reduce_dims) {
+    reduce_num *= (in_x->dims())[i];
+    update_dims[i] = 1;
+  }
+
+  // make new tensor
+  DenseTensor new_d_out(d_out->dtype());
+  new_d_out.ShareDataWith(*d_out);
+  new_d_out.Resize(phi::make_ddim(update_dims));
+
+  dev_ctx.Alloc(d_x, x.dtype());
+  auto pt_out_dtype = x.dtype();
+  auto pt_d_out = new_d_out;
+  auto pt_d_x = *d_x;
+  std::vector<const DenseTensor*> inputs = {&pt_d_out};
+  std::vector<DenseTensor*> outputs = {&pt_d_x};
+  phi::ReduceGrad<T, kps::IdentityFunctor<T, MPType>>(
+      dev_ctx,
+      &pt_d_out,
+      &pt_d_x,
+      pt_out_dtype,
+      kps::IdentityFunctor<T, MPType>());
 }
 
 }  // namespace phi
@@ -48,4 +80,3 @@ PD_REGISTER_KERNEL(sum_grad,
                    int64_t,
                    phi::dtype::complex<float>,
                    phi::dtype::complex<double>) {}
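`ReduceSumGradKernel` restores the gradient's shape by setting each reduced axis of `update_dims` to 1 and resizing `d_out` to that shape, so a plain broadcast can expand the gradient back to `x`'s original shape. A small self-contained illustration of that keep-dim reshape (names are hypothetical):

```cpp
#include <cassert>
#include <vector>

// Reduced axes become singleton dims so the tensor broadcasts back to x's shape.
std::vector<int> KeepdimShape(std::vector<int> shape, const std::vector<int>& reduce_dims) {
  for (int d : reduce_dims) shape[d] = 1;
  return shape;
}

int main() {
  // x: [2, 3, 4], summed over dims {1, 2} without keep_dim gives d_out: [2].
  // Resizing d_out to [2, 1, 1] lets broadcasting restore the [2, 3, 4] gradient.
  assert((KeepdimShape({2, 3, 4}, {1, 2}) == std::vector<int>{2, 1, 1}));
  return 0;
}
```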
@@ -40,8 +40,7 @@ void WhereKernel(const Context& ctx,
   ctx.template Alloc<T>(out);
 
   CondFunctor<T> func;
-  funcs::BroadcastKernel<ElementwiseType::kTernary, T, T>(
-      ctx, ins, &outs, -1, func);
+  funcs::ElementwiseKernel<T, CondFunctor<T>, 1>(ctx, ins, &outs, func);
 }
 
 }  // namespace phi
...
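`WhereKernel`'s three inputs (condition, x, y) already share one shape here, so the ternary broadcast call with `axis = -1` collapses to the generic elementwise driver. A toy stand-in for the `CondFunctor`-style select (all names below are illustrative, not phi's):

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical select functor: pick x where cond is true, else y.
template <typename T>
struct CondSelect {
  T operator()(bool cond, T x, T y) const { return cond ? x : y; }
};

int main() {
  std::vector<bool> cond = {true, false, true};
  std::vector<int> x = {1, 2, 3}, y = {10, 20, 30}, out(3);
  CondSelect<int> func;
  // Same-shape inputs: a plain elementwise loop, no broadcast axis needed.
  for (std::size_t i = 0; i < cond.size(); ++i) out[i] = func(cond[i], x[i], y[i]);
  assert((out == std::vector<int>{1, 20, 3}));
  return 0;
}
```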