Unverified Commit 84bb7a96 authored by zyfncg, committed by GitHub

[PHI] Adjust files of fusion kernel in PHI (#52420)

* update readme

* remove unused header file

* fix bug

* fix onednn

* fix onednn

* rename header file
Parent 690767ed
......@@ -7,9 +7,8 @@
2. We don't require a fusion kernel to have implementations for all devices
- A Fusion Kernel is generally used to accelerate a combined operation on a certain device. If all devices had to be implemented, the cost would be relatively high.
- A Fusion Kernel is generally used to accelerate a combined operation on a certain backend. If all backends had to be implemented, the cost would be relatively high.
- We don't recommend implementing a pseudo kernel that just throws an exception; if a backend is not required, the kernel can simply be left unimplemented.
- If a kernel is only implemented on a certain backend, we recommend adding a backend suffix to the kernel name (such as `fused_matmul_onednn`, `fused_fc_xpu`).
3. Fusion Kernels need to be in the `phi::fusion` namespace.
4. The file name of a Fusion Kernel follows the format `fused_[fusion operation name]_kernel.h/cc/cu`, the kernel function name follows the format `Fused[fusion operation name]Kernel`, and the kernel registration name follows the format `fused_[fusion operation name]`; a minimal sketch of these conventions follows this list.
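The sketch below illustrates rules 3 and 4 for a hypothetical `fused_example` operation; the file paths, kernel name, and dtype list are illustrative assumptions, not an actual kernel in the repository.

```cpp
// paddle/phi/kernels/fusion/fused_example_kernel.h (hypothetical)
#pragma once

#include "paddle/phi/core/dense_tensor.h"

namespace phi {
namespace fusion {

// Kernel function naming: Fused[fusion operation name]Kernel
template <typename T, typename Context>
void FusedExampleKernel(const Context& dev_ctx,
                        const DenseTensor& x,
                        DenseTensor* out);

}  // namespace fusion
}  // namespace phi
```

The matching registration lives in the device-specific source file and uses the registration name `fused_[fusion operation name]`, with the kernel symbol qualified by `phi::fusion`:

```cpp
// paddle/phi/kernels/fusion/gpu/fused_example_kernel.cu (hypothetical)
PD_REGISTER_KERNEL(fused_example,
                   GPU,
                   ALL_LAYOUT,
                   phi::fusion::FusedExampleKernel,
                   float,
                   phi::dtype::float16) {}
```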
......@@ -12,7 +12,6 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/fusion/moe_kernel.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/kernel_registry.h"
......
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/common/scalar.h"
#include "paddle/phi/core/dense_tensor.h"
namespace phi {
template <typename T, typename Context>
void FusedDropoutAddGradKernel(const Context& dev_ctx,
const DenseTensor& seed_offset,
const DenseTensor& out_grad,
const Scalar& p,
bool is_test,
const std::string& mode,
bool fix_seed,
DenseTensor* x_grad,
DenseTensor* y_grad);
} // namespace phi
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/kernel_registry.h"
namespace phi {
template <typename T, typename Context>
void FusedLinearParamGradAdd(const Context &ctx,
const DenseTensor &x,
const DenseTensor &dout,
const paddle::optional<DenseTensor> &dweight,
const paddle::optional<DenseTensor> &dbias,
bool multi_precision,
DenseTensor *dweight_out,
DenseTensor *dbias_out);
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/core/dense_tensor.h"
namespace phi {
template <typename T, typename Context>
void FusedSoftmaxMaskGradKernel(const Context& dev_ctx,
const DenseTensor& out,
const DenseTensor& out_grad,
DenseTensor* x_grad);
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/core/dense_tensor.h"
namespace phi {
template <typename T, typename Context>
void FusedSoftmaxMaskKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& mask,
DenseTensor* out);
} // namespace phi
......@@ -12,10 +12,9 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/fusion/fused_dropout_add_grad_kernel.h"
#include "paddle/phi/kernels/fusion/fused_dropout_add_kernel.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/fusion/gpu/fused_dropout_add_utils.h"
#include "paddle/phi/kernels/funcs/distribution_helper.h"
#include "paddle/phi/kernels/funcs/dropout_impl_util.h"
......@@ -33,6 +32,7 @@ static inline int NumBlocks(const int N) {
}
namespace phi {
namespace fusion {
template <typename T, typename MT>
__global__ void FuseScaleAddGrad(const T* grad,
......@@ -220,12 +220,13 @@ void FusedDropoutAddGradKernel(const Context& dev_ctx,
}
}
} // namespace fusion
} // namespace phi
PD_REGISTER_KERNEL(fused_dropout_add_grad,
GPU,
ALL_LAYOUT,
phi::FusedDropoutAddGradKernel,
phi::fusion::FusedDropoutAddGradKernel,
float,
double,
phi::dtype::bfloat16,
......
......@@ -11,8 +11,10 @@
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/fusion/fused_dropout_add_kernel.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/fusion/gpu/fused_dropout_add_utils.h"
#include "paddle/phi/kernels/funcs/dropout_impl_util.h"
#include "paddle/phi/kernels/funcs/functors.h"
......@@ -21,6 +23,7 @@
#include "paddle/phi/kernels/funcs/dropout_impl.cu.h"
namespace phi {
namespace fusion {
template <typename T1, typename T2 = T1, typename OutT = T1>
struct NoMaskFwFunctor {
......@@ -204,12 +207,13 @@ void FusedDropoutAddKernel(const Context& dev_ctx,
}
}
} // namespace fusion
} // namespace phi
PD_REGISTER_KERNEL(fused_dropout_add,
GPU,
ALL_LAYOUT,
phi::FusedDropoutAddKernel,
phi::fusion::FusedDropoutAddKernel,
float,
double,
phi::dtype::bfloat16,
......
......@@ -14,23 +14,10 @@
#pragma once
#include "paddle/phi/common/scalar.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/kernels/funcs/distribution_helper.h"
namespace phi {
template <typename T, typename Context>
void FusedDropoutAddKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& y,
const Scalar& p,
bool is_test,
const std::string& mode,
int seed,
bool fix_seed,
DenseTensor* out,
DenseTensor* seed_offset);
namespace fusion {
template <typename Context>
static inline std::vector<size_t> GetRandomCudaProp(int numel,
......@@ -52,4 +39,5 @@ static inline std::vector<size_t> GetRandomCudaProp(int numel,
return {grid_size, block_size, offset, main_offset};
}
} // namespace fusion
} // namespace phi
......@@ -12,7 +12,9 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/fusion/fused_linear_param_grad_add_kernel.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/kernel_registry.h"
#if defined(PADDLE_WITH_CUDA) && CUDA_VERSION >= 11060
#include "paddle/phi/kernels/funcs/fused_gemm_epilogue.h"
......@@ -23,6 +25,7 @@
#include "paddle/phi/kernels/reduce_sum_kernel.h"
namespace phi {
namespace fusion {
#if defined(PADDLE_WITH_CUDA) && CUDA_VERSION >= 11060
......@@ -201,12 +204,13 @@ void FusedLinearParamGradAdd(const Context &ctx,
}
#endif
} // namespace fusion
} // namespace phi
PD_REGISTER_KERNEL(fused_linear_param_grad_add,
GPU,
ALL_LAYOUT,
phi::FusedLinearParamGradAdd,
phi::fusion::FusedLinearParamGradAdd,
float,
double,
phi::dtype::float16,
......
......@@ -14,8 +14,6 @@
#include <algorithm>
#include "paddle/phi/kernels/fusion/fused_softmax_mask_kernel.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/fusion/gpu/fused_softmax_mask_utils.h"
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/core/dense_tensor.h"
namespace phi {
template <typename T, typename Context>
void MemoryEfficientAttentionBackwardKernel(
const Context& ctx,
const DenseTensor& query,
const DenseTensor& key,
const DenseTensor& value,
const paddle::optional<DenseTensor>& bias,
const paddle::optional<DenseTensor>& cu_seqlens_q,
const paddle::optional<DenseTensor>& cu_seqlens_k,
const DenseTensor& output,
const DenseTensor& logsumexp,
const DenseTensor& seed_and_offset,
const DenseTensor& output_grad,
const Scalar& max_seqlen_q,
const Scalar& max_seqlen_k,
const bool causal,
const double dropout_p,
const float scale,
DenseTensor* query_grad,
DenseTensor* key_grad,
DenseTensor* value_grad,
DenseTensor* bias_grad);
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/core/dense_tensor.h"
namespace phi {
template <typename T, typename Context>
void MemoryEfficientAttentionForwardKernel(
const Context& ctx,
const DenseTensor& query,
const DenseTensor& key,
const DenseTensor& value,
const paddle::optional<DenseTensor>& bias,
const paddle::optional<DenseTensor>& cu_seqlens_q,
const paddle::optional<DenseTensor>& cu_seqlens_k,
const paddle::optional<DenseTensor>& causal_diagonal,
const paddle::optional<DenseTensor>& seqlen_k,
const Scalar& max_seqlen_q,
const Scalar& max_seqlen_k,
const bool causal,
const double dropout_p,
const float scale,
const bool is_test,
DenseTensor* output,
DenseTensor* logsumexp,
DenseTensor* seed_and_offset);
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/core/dense_tensor.h"
namespace phi {
template <typename T, typename Context>
void MoeKernel(const Context& ctx,
const DenseTensor& x,
const DenseTensor& gate,
const DenseTensor& bmm0,
const DenseTensor& bias0,
const DenseTensor& bmm1,
const DenseTensor& bias1,
const std::string& act_type,
DenseTensor* output);
} // namespace phi
......@@ -16,6 +16,7 @@
#include "paddle/phi/kernels/onednn/conv_function.h"
namespace phi {
namespace fusion {
template <typename T, typename Context>
void FusedConv2DKernel(const Context& dev_ctx,
......@@ -132,16 +133,17 @@ void FusedConv3DKernel(const Context& dev_ctx,
out);
}
} // namespace fusion
} // namespace phi
PD_REGISTER_KERNEL(fused_conv2d,
OneDNN,
ONEDNN,
phi::FusedConv2DKernel,
phi::fusion::FusedConv2DKernel,
float,
phi::dtype::bfloat16,
uint8_t,
int8_t) {}
PD_REGISTER_KERNEL(
fused_conv3d, OneDNN, ONEDNN, phi::FusedConv3DKernel, float) {}
fused_conv3d, OneDNN, ONEDNN, phi::fusion::FusedConv3DKernel, float) {}
......@@ -16,6 +16,7 @@
#include "paddle/phi/core/kernel_registry.h"
namespace phi {
namespace fusion {
template <typename T, dnnl::algorithm BINARY_OP>
void FusedElementwiseKernel(const OneDNNContext& dev_ctx,
......@@ -153,12 +154,13 @@ void FusedElementwiseKernel(const OneDNNContext& dev_ctx,
DEFINE_ONEDNN_ELEMENTWISE_KERNEL(FusedMultiply, dnnl::algorithm::binary_mul)
DEFINE_ONEDNN_ELEMENTWISE_KERNEL(FusedDivide, dnnl::algorithm::binary_div)
} // namespace fusion
} // namespace phi
PD_REGISTER_KERNEL(fused_elementwise_mul,
OneDNN,
ONEDNN,
phi::FusedMultiplyKernel,
phi::fusion::FusedMultiplyKernel,
float,
phi::dtype::bfloat16,
int8_t,
......@@ -167,7 +169,7 @@ PD_REGISTER_KERNEL(fused_elementwise_mul,
PD_REGISTER_KERNEL(fused_elementwise_div,
OneDNN,
ONEDNN,
phi::FusedDivideKernel,
phi::fusion::FusedDivideKernel,
float,
phi::dtype::bfloat16,
int8_t,
......
......@@ -26,6 +26,7 @@ using dnnl::stream;
using phi::ReshapeToMatrix;
namespace phi {
namespace fusion {
template <typename XT, typename YT, typename OT>
class FusedMatmulOneDNNHandler
......@@ -514,12 +515,13 @@ void FusedMatmulKernel(const Context &dev_ctx,
}
}
} // namespace fusion
} // namespace phi
PD_REGISTER_KERNEL(fused_matmul,
OneDNN,
ONEDNN,
phi::FusedMatmulKernel,
phi::fusion::FusedMatmulKernel,
float,
phi::dtype::bfloat16,
int8_t,
......
......@@ -18,6 +18,7 @@
#include "paddle/phi/core/kernel_registry.h"
namespace phi {
namespace fusion {
template <typename T, typename Context>
void FusedSoftplusKernel(const Context& dev_ctx,
......@@ -55,11 +56,12 @@ void FusedSoftplusKernel(const Context& dev_ctx,
out->set_mem_desc(dst_memory_p->get_desc());
}
} // namespace fusion
} // namespace phi
PD_REGISTER_KERNEL(fused_softplus,
OneDNN,
ONEDNN,
phi::FusedSoftplusKernel,
phi::fusion::FusedSoftplusKernel,
float,
phi::dtype::bfloat16) {}
......@@ -17,6 +17,7 @@
#include "paddle/phi/kernels/transpose_kernel.h"
namespace phi {
namespace fusion {
void SetInMemDescWithSqueeze2FuseSupport(
const std::vector<int> fused_squeeze2_axes,
......@@ -166,12 +167,14 @@ void FusedTransposeKernel(const Context& dev_ctx,
out->set_mem_desc(out_md);
}
}
} // namespace fusion
} // namespace phi
PD_REGISTER_KERNEL(fused_transpose,
OneDNN,
ONEDNN,
phi::FusedTransposeKernel,
phi::fusion::FusedTransposeKernel,
float,
uint8_t,
int8_t,
......