Unverified commit 0d878f1a, authored by niuliling123 and committed by GitHub

Delete ElementwiseKernel in BroadcastKernel (#42779)

Parent c5d3bc0e
No related merge requests
......
@@ -585,26 +585,16 @@ void BroadcastKernel(const KPDevice &ctx,
                      Functor func) {
   std::vector<int> dims_size;
   dims_size.reserve(ins.size());
-  bool no_broadcast_flag = true;
   for (auto *in : ins) {
-    no_broadcast_flag &= ins[0]->dims() == in->dims();
     dims_size.emplace_back(in->dims().size());
   }
-  if (ins.size() > 0 && outs->size() > 0) {
-    no_broadcast_flag &= outs->at(0)->dims() == ins[0]->dims();
-  }
-  if (no_broadcast_flag) {
-    phi::funcs::ElementwiseKernel<OutT, Functor, NumOuts>(ctx, ins, outs, func);
-  } else {
-    axis = axis == -1
-               ? *std::max_element(dims_size.begin(), dims_size.end()) -
-                     *std::min_element(dims_size.begin(), dims_size.end())
-               : axis;
-    BroadcastKernelForDifferentVecSize<ET, InT, OutT, Functor, NumOuts>(
-        ctx, ins, outs, axis, func);
-  }
+  axis = axis == -1
+             ? *std::max_element(dims_size.begin(), dims_size.end()) -
+                   *std::min_element(dims_size.begin(), dims_size.end())
+             : axis;
+  BroadcastKernelForDifferentVecSize<ET, InT, OutT, Functor, NumOuts>(
+      ctx, ins, outs, axis, func);
 }
 
 template <typename Functor, typename T, typename OutType = T>
......
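The surviving axis rule: when callers pass axis == -1, BroadcastKernel now always derives the start axis from the rank gap between the largest- and smallest-rank inputs, with no same-shape fast path in between. A standalone C++ sketch of that rule, not Paddle code (DefaultBroadcastAxis is a name invented here):

// Sketch of the axis defaulting kept in BroadcastKernel: axis == -1 means
// "infer from the rank difference of the inputs".
#include <algorithm>
#include <cassert>
#include <iostream>
#include <vector>

int DefaultBroadcastAxis(const std::vector<int>& ranks, int axis) {
  assert(!ranks.empty());
  if (axis != -1) return axis;  // caller chose an explicit axis
  auto [min_it, max_it] = std::minmax_element(ranks.begin(), ranks.end());
  return *max_it - *min_it;  // e.g. ranks {4, 2} -> axis 2
}

int main() {
  // A rank-4 tensor broadcast against a rank-2 tensor starts matching at dim 2.
  std::cout << DefaultBroadcastAxis({4, 2}, -1) << "\n";  // prints 2
  // Same-rank inputs default to axis 0.
  std::cout << DefaultBroadcastAxis({3, 3}, -1) << "\n";  // prints 0
}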
......
@@ -81,11 +81,13 @@ void GeluGradKernel(const Context& dev_ctx,
       }
     }
 #endif
-    phi::funcs::BroadcastKernel<ElementwiseType::kBinary, T, T>(
-        dev_ctx, ins, &outs, 0, GeluWithApproximateGradFunctor<T>());
+    using Functor = GeluWithApproximateGradFunctor<T>;
+    phi::funcs::ElementwiseKernel<T, Functor, 1>(
+        dev_ctx, ins, &outs, Functor());
   } else {
-    phi::funcs::BroadcastKernel<ElementwiseType::kBinary, T, T>(
-        dev_ctx, ins, &outs, 0, GeluWithoutApproximateGradFunctor<T>());
+    using Functor = GeluWithoutApproximateGradFunctor<T>;
+    phi::funcs::ElementwiseKernel<T, Functor, 1>(
+        dev_ctx, ins, &outs, Functor());
   }
 }
......
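What the call sites now guarantee instead: GeluGradKernel's inputs and outputs share one shape, which is exactly the condition the deleted no_broadcast_flag used to test inside BroadcastKernel. A minimal standalone sketch of that condition (SameShape is a hypothetical helper, not phi code):

// Every input matches the first input, and the output matches too:
// only then is the elementwise path (no broadcasting) valid.
#include <cstdint>
#include <iostream>
#include <vector>

using Shape = std::vector<int64_t>;

bool SameShape(const std::vector<Shape>& ins, const Shape& out) {
  for (const Shape& s : ins)
    if (s != ins[0]) return false;  // an input disagrees: broadcast needed
  return ins.empty() || out == ins[0];
}

int main() {
  std::cout << SameShape({{2, 3}, {2, 3}}, {2, 3}) << "\n";  // 1: elementwise ok
  std::cout << SameShape({{2, 3}, {1, 3}}, {2, 3}) << "\n";  // 0: needs broadcast
}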
......
@@ -71,11 +71,13 @@ void GeluKernel(const Context& dev_ctx,
       }
     }
 #endif
-    phi::funcs::BroadcastKernel<ElementwiseType::kBinary, T, T>(
-        dev_ctx, ins, &outs, 0, GeluWithApproximateFunctor<T>());
+    using Functor = GeluWithApproximateFunctor<T>;
+    phi::funcs::ElementwiseKernel<T, Functor, 1>(
+        dev_ctx, ins, &outs, Functor());
   } else {
-    phi::funcs::BroadcastKernel<ElementwiseType::kBinary, T, T>(
-        dev_ctx, ins, &outs, 0, GeluWithoutApproximateFunctor<T>());
+    using Functor = GeluWithoutApproximateFunctor<T>;
+    phi::funcs::ElementwiseKernel<T, Functor, 1>(
+        dev_ctx, ins, &outs, Functor());
   }
 }
......
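Why ElementwiseKernel is sufficient here: with identical shapes there is no index remapping, just one functor applied per element. A self-contained CPU sketch of the functor pattern, using the tanh GELU approximation (GeluApproxFunctor and ElementwiseApply are illustrative stand-ins, not the phi implementations):

// Same-shape input and output, one functor call per element.
#include <cmath>
#include <cstddef>
#include <iostream>
#include <vector>

template <typename T>
struct GeluApproxFunctor {  // tanh approximation: 0.5x(1 + tanh(sqrt(2/pi)(x + 0.044715x^3)))
  T operator()(T x) const {
    return static_cast<T>(0.5) * x *
           (static_cast<T>(1) +
            std::tanh(static_cast<T>(0.7978845608028654) *
                      (x + static_cast<T>(0.044715) * x * x * x)));
  }
};

template <typename T, typename Functor>
void ElementwiseApply(const std::vector<T>& in, std::vector<T>& out,
                      Functor func) {
  for (std::size_t i = 0; i < in.size(); ++i) out[i] = func(in[i]);
}

int main() {
  std::vector<float> x = {-1.f, 0.f, 1.f}, y(3);
  ElementwiseApply(x, y, GeluApproxFunctor<float>{});
  for (float v : y) std::cout << v << " ";  // approx -0.1588 0 0.8412
}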
......
@@ -43,22 +43,19 @@ void ReduceGrad(const GPUContext& dev_ctx,
       }));
 }
 
-template <typename T,
-          typename Context,
-          template <typename, typename> class TransformOp>
+template <typename T, typename OutT, typename Context, typename Functor>
 void ReduceGradKernel(const Context& dev_ctx,
                       const DenseTensor& x,
                       const DenseTensor& out_grad,
                       const std::vector<int64_t>& dims,
                       bool keep_dim,
                       bool reduce_all,
-                      DenseTensor* x_grad) {
+                      DenseTensor* x_grad,
+                      Functor functor) {
   auto* in_x = &x;
   auto* d_out = &out_grad;
   auto* d_x = x_grad;
-  auto pt_out_dtype = x.dtype();
 
   // get reduce_dim and reduce_num for reduce_mean_grad
   int dim_size = in_x->dims().size();
   std::vector<int> reduce_dims =
@@ -79,14 +76,10 @@ void ReduceGradKernel(const Context& dev_ctx,
   auto pt_d_out = new_d_out;
   auto pt_d_x = *d_x;
-  using MPType = typename kps::details::MPTypeTrait<T>::Type;
-  phi::ReduceGrad<T, TransformOp<T, MPType>>(
-      dev_ctx,
-      &pt_d_out,
-      &pt_d_x,
-      pt_out_dtype,
-      TransformOp<T, MPType>(reduce_num));
+  std::vector<const DenseTensor*> inputs = {&pt_d_out};
+  std::vector<DenseTensor*> outputs = {&pt_d_x};
+  funcs::BroadcastKernel<phi::ElementwiseType::kUnary, T, OutT>(
+      dev_ctx, inputs, &outputs, 0, functor);
 }
 
 }  // namespace phi
......
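The refactored ReduceGradKernel reduces to one pattern: view d_out with the reduced axes restored as size 1, then broadcast it over x's shape through a caller-supplied functor. A minimal CPU sketch of that pattern, assuming a row-major n x m input reduced over its second dimension (BroadcastGrad is an invented name):

// dx[i][j] = func(d_out[i]): an (n x 1) gradient broadcast over an n x m grid.
#include <iostream>
#include <vector>

template <typename T, typename Functor>
std::vector<T> BroadcastGrad(const std::vector<T>& d_out, int n, int m,
                             Functor func) {
  std::vector<T> d_x(n * m);
  for (int i = 0; i < n; ++i)
    for (int j = 0; j < m; ++j) d_x[i * m + j] = func(d_out[i]);
  return d_x;
}

int main() {
  std::vector<float> d_out = {1.f, 2.f};  // grad of a (2,3) -> (2,) reduction
  auto identity = [](float v) { return v; };      // sum grad
  auto divide = [](float v) { return v / 3.f; };  // mean grad, reduce_num = 3
  for (float v : BroadcastGrad(d_out, 2, 3, identity)) std::cout << v << " ";
  std::cout << "\n";
  for (float v : BroadcastGrad(d_out, 2, 3, divide)) std::cout << v << " ";
}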
......
@@ -29,8 +29,23 @@ void ReduceMeanGradKernel(const Context& dev_ctx,
                           bool keep_dim,
                           bool reduce_all,
                           DenseTensor* x_grad) {
-  ReduceGradKernel<T, Context, kps::DivideFunctor>(
-      dev_ctx, x, out_grad, dims, keep_dim, reduce_all, x_grad);
+  int dim_size = x.dims().size();
+  std::vector<int> reduce_dims =
+      funcs::details::GetReduceDim(dims, dim_size, reduce_all);
+  int reduce_num = 1;
+  for (auto i : reduce_dims) {
+    reduce_num *= (x.dims())[i];
+  }
+  using MPType = typename kps::details::MPTypeTrait<T>::Type;
+  ReduceGradKernel<T, T, Context, kps::DivideFunctor<T, MPType>>(
+      dev_ctx,
+      x,
+      out_grad,
+      dims,
+      keep_dim,
+      reduce_all,
+      x_grad,
+      kps::DivideFunctor<T, MPType>(reduce_num));
 }
 
 }  // namespace phi
......
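reduce_num, now computed at the call site, is the number of elements each output value averaged over; it becomes the divisor baked into kps::DivideFunctor. A standalone sketch of the computation (ReduceNum is a hypothetical helper):

// Product of the sizes of the reduced dimensions.
#include <cstdint>
#include <iostream>
#include <vector>

int64_t ReduceNum(const std::vector<int64_t>& dims,
                  const std::vector<int>& reduce_dims) {
  int64_t n = 1;
  for (int d : reduce_dims) n *= dims[d];
  return n;
}

int main() {
  // Reducing a (4, 5, 6) tensor over dims {0, 2} averages 4 * 6 = 24 values.
  std::cout << ReduceNum({4, 5, 6}, {0, 2}) << "\n";  // prints 24
}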
......
@@ -29,8 +29,40 @@ void ReduceSumGradKernel(const Context& dev_ctx,
                          bool keep_dim,
                          bool reduce_all,
                          DenseTensor* x_grad) {
-  ReduceGradKernel<T, Context, kps::IdentityFunctor>(
-      dev_ctx, x, out_grad, dims, keep_dim, reduce_all, x_grad);
+  using MPType = typename kps::details::MPTypeTrait<T>::Type;
+  auto out_dtype = x.dtype();
+  auto* in_x = &x;
+  auto* d_out = &out_grad;
+  auto* d_x = x_grad;
+
+  // get reduce_dim and reduce_num for reduce_mean_grad
+  int dim_size = in_x->dims().size();
+  std::vector<int> reduce_dims =
+      funcs::details::GetReduceDim(dims, dim_size, reduce_all);
+
+  auto update_dims = vectorize(d_x->dims());
+  int reduce_num = 1;
+  for (auto i : reduce_dims) {
+    reduce_num *= (in_x->dims())[i];
+    update_dims[i] = 1;
+  }
+
+  // make new tensor
+  DenseTensor new_d_out(d_out->dtype());
+  new_d_out.ShareDataWith(*d_out);
+  new_d_out.Resize(phi::make_ddim(update_dims));
+
+  dev_ctx.Alloc(d_x, x.dtype());
+  auto pt_out_dtype = x.dtype();
+  auto pt_d_out = new_d_out;
+  auto pt_d_x = *d_x;
+  std::vector<const DenseTensor*> inputs = {&pt_d_out};
+  std::vector<DenseTensor*> outputs = {&pt_d_x};
+  phi::ReduceGrad<T, kps::IdentityFunctor<T, MPType>>(
+      dev_ctx,
+      &pt_d_out,
+      &pt_d_x,
+      pt_out_dtype,
+      kps::IdentityFunctor<T, MPType>());
 }
 
 }  // namespace phi
@@ -48,4 +80,3 @@ PD_REGISTER_KERNEL(sum_grad,
                    int64_t,
                    phi::dtype::complex<float>,
                    phi::dtype::complex<double>) {}
-
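The update_dims step inlined above handles keep_dim == false: the incoming gradient arrives with the reduced axes dropped, so the kernel reinserts them with size 1 before broadcasting against x's full shape. A standalone sketch (UpdateDims is an invented name):

// Size-1 axes are what make the resized d_out broadcastable over x.
#include <cstdint>
#include <iostream>
#include <vector>

std::vector<int64_t> UpdateDims(std::vector<int64_t> x_dims,
                                const std::vector<int>& reduce_dims) {
  for (int d : reduce_dims) x_dims[d] = 1;  // restore reduced axes as size 1
  return x_dims;
}

int main() {
  for (int64_t d : UpdateDims({4, 5, 6}, {0, 2}))
    std::cout << d << " ";  // prints: 1 5 1
}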
......
@@ -40,8 +40,7 @@ void WhereKernel(const Context& ctx,
   ctx.template Alloc<T>(out);
   CondFunctor<T> func;
-  funcs::BroadcastKernel<ElementwiseType::kTernary, T, T>(
-      ctx, ins, &outs, -1, func);
+  funcs::ElementwiseKernel<T, CondFunctor<T>, 1>(ctx, ins, &outs, func);
 }
 
 }  // namespace phi
......
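WhereKernel's condition and value inputs arrive with one common shape, so the ternary select also runs through the elementwise fast path. An illustrative stand-in for the functor (SelectFunctor here is not phi's CondFunctor):

// out[i] = cond[i] ? x[i] : y[i], applied elementwise over same-shape inputs.
#include <cstddef>
#include <iostream>
#include <vector>

template <typename T>
struct SelectFunctor {
  T operator()(bool cond, T x, T y) const { return cond ? x : y; }
};

int main() {
  std::vector<bool> cond = {true, false, true};
  std::vector<float> x = {1.f, 2.f, 3.f}, y = {10.f, 20.f, 30.f}, out(3);
  SelectFunctor<float> func;
  for (std::size_t i = 0; i < out.size(); ++i) out[i] = func(cond[i], x[i], y[i]);
  for (float v : out) std::cout << v << " ";  // prints: 1 20 3
}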
......
@@ -51,9 +51,9 @@ void BitwiseNotKernel(const Context& dev_ctx,
   dev_ctx.template Alloc<T>(out);
   std::vector<const DenseTensor*> ins = {&x};
   std::vector<DenseTensor*> outs = {out};
-  funcs::BitwiseNotFunctor<T> func;
-  funcs::BroadcastKernel<ElementwiseType::kUnary, T, T>(
-      dev_ctx, ins, &outs, -1, func);
+  funcs::BitwiseNotFunctor<T> unary_func;
+  funcs::ElementwiseKernel<T, funcs::BitwiseNotFunctor<T>>(
+      dev_ctx, ins, &outs, unary_func);
 }
 
 }  // namespace phi
......
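Same pattern for the unary case; note this call site omits the third (output-count) template argument that the other call sites pass as 1. A minimal stand-in for the functor (NotFunctor is illustrative, not funcs::BitwiseNotFunctor):

// Bitwise complement applied per element.
#include <cstddef>
#include <cstdint>
#include <iostream>
#include <vector>

template <typename T>
struct NotFunctor {
  T operator()(T a) const { return static_cast<T>(~a); }
};

int main() {
  std::vector<uint8_t> x = {0x00, 0x0F, 0xFF}, out(3);
  NotFunctor<uint8_t> func;
  for (std::size_t i = 0; i < x.size(); ++i) out[i] = func(x[i]);
  for (int v : out) std::cout << std::hex << v << " ";  // prints: ff f0 0
}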