diff --git a/paddle/fluid/operators/layer_norm_op.cu b/paddle/fluid/operators/layer_norm_op.cu
index d5a57dd9ddcad9d70a257738a05b3e5025a2264e..ad15b18d7feaebc27dcaa4c5ed0601cda9ca2107 100644
--- a/paddle/fluid/operators/layer_norm_op.cu
+++ b/paddle/fluid/operators/layer_norm_op.cu
@@ -109,7 +109,7 @@ struct PairForLayerNormAddFunctor {
 
 template <typename T>
 __inline__ __device__ T rsqrt(const T val) {
-  return ::rsqrt(val);
+  return static_cast<T>(1) / sqrt(val);
 }
 
 template <>
@@ -117,10 +117,17 @@ __inline__ __device__ float rsqrt(const float val) {
   return rsqrtf(val);
 }
 
+template <>
+__inline__ __device__ double rsqrt(const double val) {
+  return rsqrt(val);
+}
+
+#if CUDA_ARCH_FP16_SUPPORTED(__CUDA_ARCH__)
 template <>
 __inline__ __device__ half rsqrt(const half val) {
   return hrsqrt(val);
 }
+#endif
 
 template <typename T, typename U, int BlockDim>
 __global__ void LayerNormForward(const T *x, const U *scale, const U *bias,
@@ -841,6 +848,7 @@ class LayerNormKernel<platform::CUDADeviceContext, T>
     : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext &ctx) const override {
+    using U = LayerNormParamType<T>;
     const float epsilon = ctx.Attr<float>("epsilon");
     auto *scale = ctx.Input<Tensor>("Scale");
     auto *bias = ctx.Input<Tensor>("Bias");
@@ -854,12 +862,10 @@ class LayerNormKernel<platform::CUDADeviceContext, T>
     const auto x_dims = x->dims();
     auto *x_data = x->data<T>();
    auto *y_data = y->mutable_data<T>(ctx.GetPlace());
-    auto *mean_data = mean->mutable_data<LayerNormParamType<T>>(ctx.GetPlace());
-    auto *var_data = var->mutable_data<LayerNormParamType<T>>(ctx.GetPlace());
-    auto *scale_data =
-        (scale == nullptr ? nullptr : scale->data<LayerNormParamType<T>>());
-    auto *bias_data =
-        (bias == nullptr ? nullptr : bias->data<LayerNormParamType<T>>());
+    auto *mean_data = mean->mutable_data<U>(ctx.GetPlace());
+    auto *var_data = var->mutable_data<U>(ctx.GetPlace());
+    auto *scale_data = (scale == nullptr ? nullptr : scale->data<U>());
+    auto *bias_data = (bias == nullptr ? nullptr : bias->data<U>());
 
     auto matrix_dim = framework::flatten_to_2d(x_dims, begin_norm_axis);
     int batch_size = static_cast<int>(matrix_dim[0]);
@@ -869,7 +875,7 @@
     switch (GetDesiredBlockDim(feature_size)) {
       FIXED_BLOCK_DIM_CASE(
-          LayerNormForward<T, LayerNormParamType<T>,
+          LayerNormForward<T, U, kBlockDim><<<batch_size, kBlockDim, 0, stream>>>(
              x_data, scale_data, bias_data, y_data, mean_data, var_data,
              epsilon, feature_size));
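
For context, here is a minimal, self-contained CUDA sketch of the dispatch pattern the patch moves rsqrt toward: a portable 1/sqrt fallback in the generic template, intrinsic specializations for float and double, and the half specialization compiled only on architectures with fp16 intrinsics. This is illustrative only, not the Paddle source: RsqrtSketch, FP16_ARCH_OK, and the Demo kernel are made-up names, and FP16_ARCH_OK merely stands in for Paddle's CUDA_ARCH_FP16_SUPPORTED macro.

// rsqrt_sketch.cu -- illustrative sketch, not PaddlePaddle code.
#include <cuda_fp16.h>
#include <cstdio>

// Assumed stand-in for CUDA_ARCH_FP16_SUPPORTED: hrsqrt() needs sm_53+.
#define FP16_ARCH_OK(arch) ((arch) >= 530)

template <typename T>
__inline__ __device__ T RsqrtSketch(const T val) {
  return static_cast<T>(1) / sqrt(val);  // generic, portable fallback
}

template <>
__inline__ __device__ float RsqrtSketch(const float val) {
  return rsqrtf(val);  // single-precision intrinsic
}

template <>
__inline__ __device__ double RsqrtSketch(const double val) {
  return ::rsqrt(val);  // double-precision intrinsic from the global namespace
}

#if defined(__CUDA_ARCH__) && FP16_ARCH_OK(__CUDA_ARCH__)
template <>
__inline__ __device__ half RsqrtSketch(const half val) {
  return hrsqrt(val);  // half intrinsic, only compiled for sm_53 and newer
}
#endif

__global__ void Demo(const float *in, float *out, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) out[i] = RsqrtSketch(in[i]);
}

int main() {
  const int n = 4;
  float h_in[n] = {1.f, 4.f, 9.f, 16.f}, h_out[n];
  float *d_in, *d_out;
  cudaMalloc(&d_in, n * sizeof(float));
  cudaMalloc(&d_out, n * sizeof(float));
  cudaMemcpy(d_in, h_in, n * sizeof(float), cudaMemcpyHostToDevice);
  Demo<<<1, n>>>(d_in, d_out, n);
  cudaMemcpy(h_out, d_out, n * sizeof(float), cudaMemcpyDeviceToHost);
  for (int i = 0; i < n; ++i) printf("rsqrt(%g) = %g\n", h_in[i], h_out[i]);
  cudaFree(d_in);
  cudaFree(d_out);
  return 0;
}

Guarding the half specialization matters because hrsqrt only exists for compute capability 5.3 and newer; on older targets the preprocessor drops that specialization and the generic 1/sqrt fallback is what the compiler sees, which is exactly why the patch also replaces the unconditional ::rsqrt in the primary template.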