Unverified commit 61486bf2, authored by Chen Weihang, committed by GitHub

polish fusion kernel naming (#48609)

Parent: a686b3cf
@@ -10,4 +10,6 @@
   - Fusion Kernel is generally used to accelerate a combined operation on a certain device. If it had to be implemented for all devices, the cost would be relatively high.
   - We don't recommend implementing a pseudo kernel that just throws an exception; if it is not required, it can be left unimplemented.
-3. Fusion Kernel needs to be in the `phi/fusion` namespace
+3. Fusion Kernel needs to be in the `phi/fusion` namespace.
+4. The file naming of the Fusion Kernel needs to follow the format `fused_[fusion operation name]_kernel.h/cc/cu`, the kernel function naming needs to follow the format `Fused[fusion operation name]Kernel`, and the kernel registration naming needs to follow the format `fused_[fusion operation name]`.
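To make rule 4 concrete, here is a minimal sketch for a hypothetical fusion operation named `example_op` (a placeholder, not an existing kernel); the include paths, file split, and trivial kernel body are assumptions, while the naming pattern, namespace, and registration macro mirror the `fused_softmax_mask` hunks that follow.

```cpp
// fused_example_op_kernel.h -- file name follows fused_[fusion operation name]_kernel.h
// NOTE: "example_op" is a hypothetical placeholder; include paths are assumed.
#include "paddle/phi/core/dense_tensor.h"

namespace phi {
namespace fusion {

// Kernel function name follows Fused[fusion operation name]Kernel.
template <typename T, typename Context>
void FusedExampleOpKernel(const Context& dev_ctx,
                          const DenseTensor& x,
                          DenseTensor* out);

}  // namespace fusion
}  // namespace phi

// fused_example_op_kernel.cu -- registration name follows fused_[fusion operation name].
#include "paddle/phi/core/kernel_registry.h"

namespace phi {
namespace fusion {

template <typename T, typename Context>
void FusedExampleOpKernel(const Context& dev_ctx,
                          const DenseTensor& x,
                          DenseTensor* out) {
  // Placeholder body: only allocates the output. A real fusion kernel would
  // launch a fused device kernel here, as in the fused_softmax_mask hunks below.
  dev_ctx.template Alloc<T>(out);
}

}  // namespace fusion
}  // namespace phi

PD_REGISTER_KERNEL(fused_example_op,
                   GPU,
                   ALL_LAYOUT,
                   phi::fusion::FusedExampleOpKernel,
                   float,
                   phi::dtype::float16) {}
```

In the hunks below, the same triple appears for the real op: kernel functions `FusedSoftmaxMaskKernel` / `FusedSoftmaxMaskGradKernel` in the `phi::fusion` namespace, registered as `fused_softmax_mask` / `fused_softmax_mask_grad`.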
@@ -19,9 +19,9 @@
 namespace phi {
 template <typename T, typename Context>
-void SoftmaxMaskFuseGradKernel(const Context& dev_ctx,
-                               const DenseTensor& out,
-                               const DenseTensor& out_grad,
-                               DenseTensor* x_grad);
+void FusedSoftmaxMaskGradKernel(const Context& dev_ctx,
+                                const DenseTensor& out,
+                                const DenseTensor& out_grad,
+                                DenseTensor* x_grad);
 }  // namespace phi
@@ -19,9 +19,9 @@
 namespace phi {
 template <typename T, typename Context>
-void SoftmaxMaskFuseKernel(const Context& dev_ctx,
-                           const DenseTensor& x,
-                           const DenseTensor& mask,
-                           DenseTensor* out);
+void FusedSoftmaxMaskKernel(const Context& dev_ctx,
+                            const DenseTensor& x,
+                            const DenseTensor& mask,
+                            DenseTensor* out);
 }  // namespace phi
@@ -118,10 +118,10 @@ __global__ void SoftmaxMaskFuseGradGPUKernel(const T* grad_input,
 }
 template <typename T, typename Context>
-void SoftmaxMaskFuseGradKernel(const Context& dev_ctx,
-                               const DenseTensor& out,
-                               const DenseTensor& out_grad,
-                               DenseTensor* x_grad) {
+void FusedSoftmaxMaskGradKernel(const Context& dev_ctx,
+                                const DenseTensor& out,
+                                const DenseTensor& out_grad,
+                                DenseTensor* x_grad) {
   auto* grad_x_data = dev_ctx.template Alloc<T>(x_grad);
   auto* grad_y_data = out_grad.data<T>();
   auto* softmax_rst_data = out.data<T>();
@@ -196,6 +196,6 @@ void SoftmaxMaskFuseGradKernel(const Context& dev_ctx,
 PD_REGISTER_KERNEL(fused_softmax_mask_grad,
                    GPU,
                    ALL_LAYOUT,
-                   phi::fusion::SoftmaxMaskFuseGradKernel,
+                   phi::fusion::FusedSoftmaxMaskGradKernel,
                    float,
                    phi::dtype::float16) {}
@@ -146,10 +146,10 @@ __global__ void SoftmaxMaskFuseGPUKernel(const T* x_data,
 // T only supports fp16
 // leave as template only for future update
 template <typename T, typename Context>
-void SoftmaxMaskFuseKernel(const Context& dev_ctx,
-                           const DenseTensor& x,
-                           const DenseTensor& mask,
-                           DenseTensor* out) {
+void FusedSoftmaxMaskKernel(const Context& dev_ctx,
+                            const DenseTensor& x,
+                            const DenseTensor& mask,
+                            DenseTensor* out) {
   auto* x_data = x.data<T>();
   auto* mask_data = mask.data<T>();
   auto* y_data = dev_ctx.template Alloc<T>(out);
@@ -275,6 +275,6 @@ void SoftmaxMaskFuseKernel(const Context& dev_ctx,
 PD_REGISTER_KERNEL(fused_softmax_mask,
                    GPU,
                    ALL_LAYOUT,
-                   phi::fusion::SoftmaxMaskFuseKernel,
+                   phi::fusion::FusedSoftmaxMaskKernel,
                    float,
                    phi::dtype::float16) {}