未验证 提交 86d92092 编写于 作者: MarDino 提交者: GitHub

Use phi layernorm (#48276)

上级 d7540a4a
......@@ -20,6 +20,7 @@ limitations under the License. */
#include "paddle/fluid/operators/fused/fused_layernorm_residual_dropout_bias.h"
#include "paddle/fluid/operators/fused/fused_residual_dropout_bias.h"
#include "paddle/phi/kernels/funcs/functors.h"
#include "paddle/phi/kernels/layer_norm_kernel.h"
DECLARE_bool(use_fast_math);
......@@ -347,6 +348,18 @@ class FusedDropoutHelper {
DropoutParam dropout_param_;
};
// Maps a phi element type to the raw device type expected by
// phi::LayerNormDirectCUDAFunctor. The generic case is an identity mapping.
template <typename T>
struct PDDataTypeTraits {
using DataType = T;
};
// Specialization: LayerNormDirectCUDAFunctor is registered for the CUDA
// `half` type (see the explicit instantiations in the layer_norm kernel TU),
// so phi::dtype::float16 must be converted to `half` before the call.
template <>
struct PDDataTypeTraits<phi::dtype::float16> {
using DataType = half;
};
template <typename T,
typename MaskType,
typename InType = T,
......@@ -383,13 +396,22 @@ class FusedDropoutLayerNormHelper
OutType* out,
LayerNormParamType<T>* mean,
LayerNormParamType<T>* variance) {
using U = LayerNormParamType<T>;
switch (GetDesiredBlockDim(this->cols_)) {
FIXED_BLOCK_DIM_CASE(
LayerNormForward<T, U, kBlockDim, false, InType, OutType>
<<<this->rows_, kBlockDim, 0, ctx.stream()>>>(
src, gamma, beta, out, mean, variance, epsilon_, this->cols_));
}
using InDataType = typename PDDataTypeTraits<InType>::DataType;
using OutDataType = typename PDDataTypeTraits<OutType>::DataType;
phi::LayerNormDirectCUDAFunctor<InDataType, LayerNormParamType<T>>
layer_norm;
std::vector<int> src_shape{this->rows_, this->cols_};
layer_norm(ctx.stream(),
reinterpret_cast<const InDataType*>(src),
src_shape,
beta,
gamma,
reinterpret_cast<OutDataType*>(out),
mean,
variance,
1,
epsilon_);
}
void LayerNormGrad(const phi::GPUContext& ctx,
......
......@@ -50,6 +50,7 @@ void LayerNormDirectCUDAFunctor<T, U>::operator()(gpuStream_t stream,
}
// Explicit instantiations of LayerNormDirectCUDAFunctor for the element
// types supported by callers (second parameter is the statistics/parameter
// type — presumably LayerNormParamType<T>; confirm against the header).
template class LayerNormDirectCUDAFunctor<float, float>;
template class LayerNormDirectCUDAFunctor<double, double>;
// The half-precision instantiation is compiled only for CUDA builds and is
// explicitly excluded on ROCm/HIP.
#if defined(PADDLE_WITH_CUDA) && !defined(PADDLE_WITH_HIP)
template class LayerNormDirectCUDAFunctor<half, float>;
#endif
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册