/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

15
#include "paddle/fluid/framework/convert_utils.h"
16
#include "paddle/fluid/operators/layer_norm_kernel.cu.h"
Y
Yi Wang 已提交
17
#include "paddle/fluid/operators/layer_norm_op.h"
F
furnace 已提交
18
#include "paddle/fluid/platform/float16.h"
C
chengduoZH 已提交
19

S
sneaxiy 已提交
20 21 22
namespace paddle {
namespace operators {

P
Pei Yang 已提交
23
// Launches the layer-norm forward kernel directly on `stream`.
//
// `input_shape` is turned into a dims object and flattened to a 2-D matrix at
// `begin_norm_axis`. The kernel is launched with a grid equal to the first
// dimension of that matrix and a block size selected by GetDesiredBlockDim
// from the second dimension. `input`, `scale`, `bias`, `output`, `mean`,
// `variance` and `eps` are forwarded to LayerNormForward unchanged.
//
// Throws InvalidArgument when GetDesiredBlockDim returns a value that none of
// the FIXED_BLOCK_DIM_CASE labels handles.
template <typename T>
void LayerNormDirectCUDAFunctor<T>::operator()(
    gpuStream_t stream, const T *input, std::vector<int> input_shape,
    const T *bias, const T *scale, T *output, T *mean, T *variance,
    int begin_norm_axis, float eps) {
  auto flat_dims =
      pten::flatten_to_2d(pten::make_ddim(input_shape), begin_norm_axis);
  const int64_t num_rows = static_cast<int64_t>(flat_dims[0]);
  const int64_t row_len = static_cast<int64_t>(flat_dims[1]);

  // FIXED_BLOCK_DIM_CASE supplies the `case` labels and defines kBlockDim
  // for the launch expression passed to it.
  switch (GetDesiredBlockDim(row_len)) {
    FIXED_BLOCK_DIM_CASE(
        LayerNormForward<T, T, kBlockDim><<<num_rows, kBlockDim, 0, stream>>>(
            input, scale, bias, output, mean, variance, eps, row_len));
    default:
      PADDLE_THROW(platform::errors::InvalidArgument(
          "Product from begin_norm_axis to end in layer_norm must be larger "
          "than 1"));
      break;
  }
}

S
sneaxiy 已提交
46 47 48 49 50
// CUDA forward kernel of the layer_norm operator.
//
// Reads the inputs "X", "Scale" and "Bias" (Scale and Bias may each be
// absent) and the attributes "epsilon" and "begin_norm_axis"; writes the
// outputs "Y", "Mean" and "Variance". X is flattened to a 2-D matrix at
// begin_norm_axis: dimension 0 is `batch_size`, dimension 1 is
// `feature_size`.
//
// Data types: Y has type T, Mean and Variance have type
// U = LayerNormParamType<T>. When both Scale and Bias are given they must
// have the same data type, and the Scale/Bias type must be either the type of
// X or U; otherwise InvalidArgument is raised.
template <typename T>
class LayerNormKernel<platform::CUDADeviceContext, T>
    : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext &ctx) const override {
    using U = LayerNormParamType<T>;
    const float epsilon = ctx.Attr<float>("epsilon");
    auto *scale = ctx.Input<Tensor>("Scale");
    auto *bias = ctx.Input<Tensor>("Bias");
    auto *x = ctx.Input<Tensor>("X");

    auto *y = ctx.Output<Tensor>("Y");
    auto *mean = ctx.Output<Tensor>("Mean");
    auto *var = ctx.Output<Tensor>("Variance");
    const auto begin_norm_axis = ctx.Attr<int>("begin_norm_axis");

    const auto x_dims = x->dims();
    auto *x_data = x->data<T>();
    auto *y_data = y->mutable_data<T>(ctx.GetPlace());
    auto *mean_data = mean->mutable_data<U>(ctx.GetPlace());
    auto *var_data = var->mutable_data<U>(ctx.GetPlace());

    // Scale and Bias are fetched untyped here; they are cast to T or U at the
    // launch site once their data type has been determined below.
    auto *void_scale_data = (scale == nullptr ? nullptr : scale->data());
    auto *void_bias_data = (bias == nullptr ? nullptr : bias->data());

    framework::proto::VarType::Type x_dtype =
        framework::TransToProtoVarType(x->dtype());
    // Resolve the Scale/Bias data type: taken from Scale if present, else
    // from Bias if present, else it defaults to the data type of X.
    framework::proto::VarType::Type scale_bias_dtype;
    if (void_scale_data != nullptr) {
      scale_bias_dtype = framework::TransToProtoVarType(scale->dtype());
      if (void_bias_data != nullptr) {
        PADDLE_ENFORCE_EQ(scale_bias_dtype,
                          framework::TransToProtoVarType(bias->dtype()),
                          platform::errors::InvalidArgument(
                              "The Scale and Bias of layer_norm op "
                              "should have the same data type."));
      }
    } else {
      scale_bias_dtype = (void_bias_data != nullptr
                              ? framework::TransToProtoVarType(bias->dtype())
                              : x_dtype);
    }

    // If Scale/Bias do not share the data type of X, the only other accepted
    // type is U.
    bool is_scale_bias_same_dtype_with_x = x_dtype == scale_bias_dtype;
    if (!is_scale_bias_same_dtype_with_x) {
      PADDLE_ENFORCE_EQ(scale_bias_dtype,
                        framework::DataTypeTrait<U>::DataType(),
                        platform::errors::InvalidArgument(
                            "Unsupported data type of Scale and Bias: %s",
                            framework::DataTypeToString(scale_bias_dtype)));
    }

    auto matrix_dim = pten::flatten_to_2d(x_dims, begin_norm_axis);
    int64_t batch_size = static_cast<int64_t>(matrix_dim[0]);
    int64_t feature_size = static_cast<int64_t>(matrix_dim[1]);

    auto stream = ctx.cuda_device_context().stream();

// Generic launch path: grid = batch_size, block size chosen by
// GetDesiredBlockDim(feature_size). ScaleBiasT is the element type that the
// untyped Scale/Bias pointers are cast to.
#define PADDLE_LAUNCH_LAYERNORM_FWD(ScaleBiasT, IsScaleBiasSameDTypeWithX) \
  do {                                                                     \
    switch (GetDesiredBlockDim(feature_size)) {                            \
      FIXED_BLOCK_DIM_CASE(                                                \
          LayerNormForward<T, U, kBlockDim, IsScaleBiasSameDTypeWithX><<<  \
              batch_size, kBlockDim, 0, stream>>>(                         \
              x_data, static_cast<const ScaleBiasT *>(void_scale_data),    \
              static_cast<const ScaleBiasT *>(void_bias_data), y_data,     \
              mean_data, var_data, epsilon, feature_size));                \
      default:                                                             \
        PADDLE_THROW(platform::errors::InvalidArgument(                    \
            "Product from begin_norm_axis to end must be larger than 1")); \
        break;                                                             \
    }                                                                      \
  } while (0)

#ifdef PADDLE_WITH_CUDA
    // The dedicated kernel is used only when feature_size is exactly 1024 and
    // both Scale and Bias are supplied.
    const bool can_call_1024_kernel =
        feature_size == 1024 && scale != nullptr && bias != nullptr;
    if (can_call_1024_kernel) {
      const int WARPS_M = 4;
      const int WARPS_N = 1;
      const int THREADS_PER_WARP = 32;
      const int BYTES_PER_LDG = 16;
      // Number of T elements that fit in BYTES_PER_LDG bytes.
      const int VecSize = BYTES_PER_LDG / sizeof(T);

      const int THREADS_PER_CTA = WARPS_N * THREADS_PER_WARP * WARPS_M;
      const int ROWS_PER_CTA = WARPS_M;

      // Integer ceil-division. Going through float would round batch_size
      // values above 2^24 and could yield a grid that is one block short.
      const int grid = static_cast<int>((batch_size + ROWS_PER_CTA - 1) /
                                        ROWS_PER_CTA);
      if (is_scale_bias_same_dtype_with_x) {
        ln_fwd_1024_kernel<T, U, T, VecSize, WARPS_M, WARPS_N,
                           BYTES_PER_LDG><<<grid, THREADS_PER_CTA, 0, stream>>>(
            batch_size, feature_size, epsilon, x_data,
            static_cast<const T *>(void_scale_data),
            static_cast<const T *>(void_bias_data), mean_data, var_data,
            y_data);
      } else {
        ln_fwd_1024_kernel<T, U, U, VecSize, WARPS_M, WARPS_N,
                           BYTES_PER_LDG><<<grid, THREADS_PER_CTA, 0, stream>>>(
            batch_size, feature_size, epsilon, x_data,
            static_cast<const U *>(void_scale_data),
            static_cast<const U *>(void_bias_data), mean_data, var_data,
            y_data);
      }
    } else {
#endif
      if (is_scale_bias_same_dtype_with_x) {
        PADDLE_LAUNCH_LAYERNORM_FWD(T, true);
      } else {
        PADDLE_LAUNCH_LAYERNORM_FWD(U, false);
      }
#ifdef PADDLE_WITH_CUDA
    }
#endif

#undef PADDLE_LAUNCH_LAYERNORM_FWD
  }
};

// CUDA backward kernel of the layer_norm operator.
//
// Reads the inputs "X", "Mean", "Variance", "Scale", "Bias" and the gradient
// of "Y", plus the attributes "epsilon" and "begin_norm_axis". Writes the
// gradients of "X", "Scale" and "Bias"; each of those outputs may be absent,
// in which case a null pointer is handed to LayerNormBackward for it.
//
// X and the gradient of Y are read as T; Mean and Variance are read as
// U = LayerNormParamType<T>. The Scale/Bias gradients are allocated as T when
// the Scale/Bias data type equals that of X, and as U otherwise.
template <typename T>
class LayerNormGradKernel<platform::CUDADeviceContext, T>
    : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext &ctx) const override {
    using U = LayerNormParamType<T>;
    const float epsilon = ctx.Attr<float>("epsilon");
    // d_x, d_scale, d_bias may be nullptr
    auto *d_x = ctx.Output<Tensor>(framework::GradVarName("X"));
    auto *d_scale = ctx.Output<Tensor>(framework::GradVarName("Scale"));
    auto *d_bias = ctx.Output<Tensor>(framework::GradVarName("Bias"));

    auto *x = ctx.Input<Tensor>("X");
    auto *mean = ctx.Input<Tensor>("Mean");
    auto *var = ctx.Input<Tensor>("Variance");
    auto *scale = ctx.Input<Tensor>("Scale");
    auto *bias = ctx.Input<Tensor>("Bias");
    auto *d_y = ctx.Input<Tensor>(framework::GradVarName("Y"));

    const auto &x_dims = x->dims();
    const auto begin_norm_axis = ctx.Attr<int>("begin_norm_axis");
    auto matrix_dim = pten::flatten_to_2d(x_dims, begin_norm_axis);
    int64_t batch_size = static_cast<int64_t>(matrix_dim[0]);
    int64_t feature_size = static_cast<int64_t>(matrix_dim[1]);

    auto *x_data = x->data<T>();
    auto *d_y_data = d_y->data<T>();

    auto *mean_data = mean->data<U>();
    auto *var_data = var->data<U>();

    // The X gradient always has type T, so it is allocated once here and
    // shared by both instantiations of the launch macro below.
    auto *d_x_data =
        (d_x == nullptr ? nullptr : d_x->mutable_data<T>(ctx.GetPlace()));

    framework::proto::VarType::Type x_dtype =
        framework::TransToProtoVarType(x->dtype());
    // Resolve the Scale/Bias data type: taken from Scale if present, else
    // from Bias if present, else it defaults to the data type of X.
    framework::proto::VarType::Type scale_bias_dtype;
    if (scale != nullptr) {
      scale_bias_dtype = framework::TransToProtoVarType(scale->dtype());
    } else {
      // FIXME(zengjinle): do not find a better way to get the right
      // data type of the d_scale and d_bias if scale == nullptr.
      if (bias != nullptr) {
        scale_bias_dtype = framework::TransToProtoVarType(bias->dtype());
      } else {
        scale_bias_dtype = x_dtype;
      }
    }

// Fetches Scale and allocates the Scale/Bias gradients as ScaleBiasT, then
// calls LayerNormBackward. Uses x_data, d_y_data, mean_data, var_data and
// d_x_data from the enclosing scope.
#define PADDLE_LAUNCH_LAYERNORM_BWD(ScaleBiasT, IsScaleBiasSameDTypeWithX) \
  do {                                                                     \
    auto *scale_data =                                                     \
        (scale == nullptr ? nullptr : scale->data<ScaleBiasT>());          \
    auto *d_scale_data =                                                   \
        (d_scale == nullptr ? nullptr : d_scale->mutable_data<ScaleBiasT>( \
                                            ctx.GetPlace()));              \
    auto *d_bias_data =                                                    \
        (d_bias == nullptr ? nullptr : d_bias->mutable_data<ScaleBiasT>(   \
                                           ctx.GetPlace()));               \
    LayerNormBackward<T, U, IsScaleBiasSameDTypeWithX>(                    \
        x_data, d_y_data, scale_data, mean_data, var_data, d_x_data,       \
        d_scale_data, d_bias_data, epsilon, batch_size, feature_size,      \
        ctx.cuda_device_context());                                        \
  } while (0)

    if (scale_bias_dtype == x_dtype) {
      PADDLE_LAUNCH_LAYERNORM_BWD(T, true);
    } else {
      PADDLE_LAUNCH_LAYERNORM_BWD(U, false);
    }

#undef PADDLE_LAUNCH_LAYERNORM_BWD
  }
};
F
furnace 已提交
244

P
Pei Yang 已提交
245
// Explicitly instantiate the functor for float; float is the only type
// instantiated in this file.
template class LayerNormDirectCUDAFunctor<float>;
S
sneaxiy 已提交
246 247 248
}  // namespace operators
}  // namespace paddle

C
chengduoZH 已提交
249
namespace ops = paddle::operators;
namespace plat = paddle::platform;

// Register the GPU kernels of layer_norm and layer_norm_grad. HIP builds
// register float and float16 only; all other builds additionally register
// double.
#ifdef PADDLE_WITH_HIP
// MIOpen does not support double.
REGISTER_OP_CUDA_KERNEL(
    layer_norm, ops::LayerNormKernel<plat::CUDADeviceContext, float>,
    ops::LayerNormKernel<plat::CUDADeviceContext, plat::float16>);
REGISTER_OP_CUDA_KERNEL(
    layer_norm_grad, ops::LayerNormGradKernel<plat::CUDADeviceContext, float>,
    ops::LayerNormGradKernel<plat::CUDADeviceContext, plat::float16>);
#else
REGISTER_OP_CUDA_KERNEL(
    layer_norm, ops::LayerNormKernel<plat::CUDADeviceContext, float>,
    ops::LayerNormKernel<plat::CUDADeviceContext, double>,
    ops::LayerNormKernel<plat::CUDADeviceContext, plat::float16>);
REGISTER_OP_CUDA_KERNEL(
    layer_norm_grad, ops::LayerNormGradKernel<plat::CUDADeviceContext, float>,
    ops::LayerNormGradKernel<plat::CUDADeviceContext, double>,
    ops::LayerNormGradKernel<plat::CUDADeviceContext, plat::float16>);
#endif