未验证 提交 0d878f1a 编写于 作者: N niuliling123 提交者: GitHub

Delete ElementwiseKernel in BroadcastKernel (#42779)

上级 c5d3bc0e
......@@ -585,26 +585,16 @@ void BroadcastKernel(const KPDevice &ctx,
Functor func) {
std::vector<int> dims_size;
dims_size.reserve(ins.size());
bool no_broadcast_flag = true;
for (auto *in : ins) {
no_broadcast_flag &= ins[0]->dims() == in->dims();
dims_size.emplace_back(in->dims().size());
}
if (ins.size() > 0 && outs->size() > 0) {
no_broadcast_flag &= outs->at(0)->dims() == ins[0]->dims();
}
if (no_broadcast_flag) {
phi::funcs::ElementwiseKernel<OutT, Functor, NumOuts>(ctx, ins, outs, func);
} else {
axis = axis == -1
? *std::max_element(dims_size.begin(), dims_size.end()) -
*std::min_element(dims_size.begin(), dims_size.end())
: axis;
BroadcastKernelForDifferentVecSize<ET, InT, OutT, Functor, NumOuts>(
ctx, ins, outs, axis, func);
}
}
template <typename Functor, typename T, typename OutType = T>
......
......@@ -81,11 +81,13 @@ void GeluGradKernel(const Context& dev_ctx,
}
}
#endif
phi::funcs::BroadcastKernel<ElementwiseType::kBinary, T, T>(
dev_ctx, ins, &outs, 0, GeluWithApproximateGradFunctor<T>());
using Functor = GeluWithApproximateGradFunctor<T>;
phi::funcs::ElementwiseKernel<T, Functor, 1>(
dev_ctx, ins, &outs, Functor());
} else {
phi::funcs::BroadcastKernel<ElementwiseType::kBinary, T, T>(
dev_ctx, ins, &outs, 0, GeluWithoutApproximateGradFunctor<T>());
using Functor = GeluWithoutApproximateGradFunctor<T>;
phi::funcs::ElementwiseKernel<T, Functor, 1>(
dev_ctx, ins, &outs, Functor());
}
}
......
......@@ -71,11 +71,13 @@ void GeluKernel(const Context& dev_ctx,
}
}
#endif
phi::funcs::BroadcastKernel<ElementwiseType::kBinary, T, T>(
dev_ctx, ins, &outs, 0, GeluWithApproximateFunctor<T>());
using Functor = GeluWithApproximateFunctor<T>;
phi::funcs::ElementwiseKernel<T, Functor, 1>(
dev_ctx, ins, &outs, Functor());
} else {
phi::funcs::BroadcastKernel<ElementwiseType::kBinary, T, T>(
dev_ctx, ins, &outs, 0, GeluWithoutApproximateFunctor<T>());
using Functor = GeluWithoutApproximateFunctor<T>;
phi::funcs::ElementwiseKernel<T, Functor, 1>(
dev_ctx, ins, &outs, Functor());
}
}
......
......@@ -43,22 +43,19 @@ void ReduceGrad(const GPUContext& dev_ctx,
}));
}
template <typename T,
typename Context,
template <typename, typename> class TransformOp>
template <typename T, typename OutT, typename Context, typename Functor>
void ReduceGradKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& out_grad,
const std::vector<int64_t>& dims,
bool keep_dim,
bool reduce_all,
DenseTensor* x_grad) {
DenseTensor* x_grad,
Functor functor) {
auto* in_x = &x;
auto* d_out = &out_grad;
auto* d_x = x_grad;
auto pt_out_dtype = x.dtype();
// get reduce_dim and reduce_num for reduce_mean_grad
int dim_size = in_x->dims().size();
std::vector<int> reduce_dims =
......@@ -79,14 +76,10 @@ void ReduceGradKernel(const Context& dev_ctx,
auto pt_d_out = new_d_out;
auto pt_d_x = *d_x;
using MPType = typename kps::details::MPTypeTrait<T>::Type;
phi::ReduceGrad<T, TransformOp<T, MPType>>(
dev_ctx,
&pt_d_out,
&pt_d_x,
pt_out_dtype,
TransformOp<T, MPType>(reduce_num));
std::vector<const DenseTensor*> inputs = {&pt_d_out};
std::vector<DenseTensor*> outputs = {&pt_d_x};
funcs::BroadcastKernel<phi::ElementwiseType::kUnary, T, OutT>(
dev_ctx, inputs, &outputs, 0, functor);
}
} // namespace phi
......
......@@ -29,8 +29,23 @@ void ReduceMeanGradKernel(const Context& dev_ctx,
bool keep_dim,
bool reduce_all,
DenseTensor* x_grad) {
ReduceGradKernel<T, Context, kps::DivideFunctor>(
dev_ctx, x, out_grad, dims, keep_dim, reduce_all, x_grad);
int dim_size = x.dims().size();
std::vector<int> reduce_dims =
funcs::details::GetReduceDim(dims, dim_size, reduce_all);
int reduce_num = 1;
for (auto i : reduce_dims) {
reduce_num *= (x.dims())[i];
}
using MPType = typename kps::details::MPTypeTrait<T>::Type;
ReduceGradKernel<T, T, Context, kps::DivideFunctor<T, MPType>>(
dev_ctx,
x,
out_grad,
dims,
keep_dim,
reduce_all,
x_grad,
kps::DivideFunctor<T, MPType>(reduce_num));
}
} // namespace phi
......
......@@ -29,8 +29,40 @@ void ReduceSumGradKernel(const Context& dev_ctx,
bool keep_dim,
bool reduce_all,
DenseTensor* x_grad) {
ReduceGradKernel<T, Context, kps::IdentityFunctor>(
dev_ctx, x, out_grad, dims, keep_dim, reduce_all, x_grad);
using MPType = typename kps::details::MPTypeTrait<T>::Type;
auto out_dtype = x.dtype();
auto* in_x = &x;
auto* d_out = &out_grad;
auto* d_x = x_grad;
// get reduce_dim and reduce_num for reduce_mean_grad
int dim_size = in_x->dims().size();
std::vector<int> reduce_dims =
funcs::details::GetReduceDim(dims, dim_size, reduce_all);
auto update_dims = vectorize(d_x->dims());
int reduce_num = 1;
for (auto i : reduce_dims) {
reduce_num *= (in_x->dims())[i];
update_dims[i] = 1;
}
// make new tensor
DenseTensor new_d_out(d_out->dtype());
new_d_out.ShareDataWith(*d_out);
new_d_out.Resize(phi::make_ddim(update_dims));
dev_ctx.Alloc(d_x, x.dtype());
auto pt_out_dtype = x.dtype();
auto pt_d_out = new_d_out;
auto pt_d_x = *d_x;
std::vector<const DenseTensor*> inputs = {&pt_d_out};
std::vector<DenseTensor*> outputs = {&pt_d_x};
phi::ReduceGrad<T, kps::IdentityFunctor<T, MPType>>(
dev_ctx,
&pt_d_out,
&pt_d_x,
pt_out_dtype,
kps::IdentityFunctor<T, MPType>());
}
} // namespace phi
......@@ -48,4 +80,3 @@ PD_REGISTER_KERNEL(sum_grad,
int64_t,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
......@@ -40,8 +40,7 @@ void WhereKernel(const Context& ctx,
ctx.template Alloc<T>(out);
CondFunctor<T> func;
funcs::BroadcastKernel<ElementwiseType::kTernary, T, T>(
ctx, ins, &outs, -1, func);
funcs::ElementwiseKernel<T, CondFunctor<T>, 1>(ctx, ins, &outs, func);
}
} // namespace phi
......
......@@ -51,9 +51,9 @@ void BitwiseNotKernel(const Context& dev_ctx,
dev_ctx.template Alloc<T>(out);
std::vector<const DenseTensor*> ins = {&x};
std::vector<DenseTensor*> outs = {out};
funcs::BitwiseNotFunctor<T> func;
funcs::BroadcastKernel<ElementwiseType::kUnary, T, T>(
dev_ctx, ins, &outs, -1, func);
funcs::BitwiseNotFunctor<T> unary_func;
funcs::ElementwiseKernel<T, funcs::BitwiseNotFunctor<T>>(
dev_ctx, ins, &outs, unary_func);
}
} // namespace phi
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册