layer_norm_op.cu 9.8 KB
Newer Older
S
sneaxiy 已提交
1
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
C
chengduoZH 已提交
2 3 4 5 6 7 8 9 10 11 12 13 14

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

15
#include "paddle/fluid/operators/layer_norm_kernel.cu.h"
Y
Yi Wang 已提交
16
#include "paddle/fluid/operators/layer_norm_op.h"
F
furnace 已提交
17
#include "paddle/fluid/platform/float16.h"
C
chengduoZH 已提交
18

S
sneaxiy 已提交
19 20 21
namespace paddle {
namespace operators {

P
Pei Yang 已提交
22
template <typename T>
void LayerNormDirectCUDAFunctor<T>::operator()(gpuStream_t stream,
                                               const T *input,
                                               std::vector<int> input_shape,
                                               const T *bias, const T *scale,
                                               T *output, T *mean, T *variance,
                                               int begin_norm_axis, float eps) {
  // Split the input shape at begin_norm_axis into a 2-D view:
  // matrix[0] = number of rows (one kernel block each, see launch below),
  // matrix[1] = product of the dims from begin_norm_axis to the end.
  const auto matrix = framework::flatten_to_2d(
      framework::make_ddim(input_shape), begin_norm_axis);
  const int64_t rows = static_cast<int64_t>(matrix[0]);
  const int64_t cols = static_cast<int64_t>(matrix[1]);

  // The thread-block width is chosen from the row length; unsupported widths
  // fall through to the default branch and raise InvalidArgument.
  const auto block_dim = GetDesiredBlockDim(cols);
  switch (block_dim) {
    FIXED_BLOCK_DIM_CASE(
        LayerNormForward<T, T, kBlockDim><<<rows, kBlockDim, 0, stream>>>(
            input, scale, bias, output, mean, variance, eps, cols));
    default:
      PADDLE_THROW(platform::errors::InvalidArgument(
          "Product from begin_norm_axis to end in layer_norm must be larger "
          "than 1"));
      break;
  }
}

S
sneaxiy 已提交
45 46 47 48 49
template <typename T>
class LayerNormKernel<platform::CUDADeviceContext, T>
    : public framework::OpKernel<T> {
 public:
  // Forward layer normalization on the GPU.
  //
  // Inputs:     X (read as T), optional Scale and Bias.
  // Outputs:    Y (written as T), Mean and Variance (written as
  //             U = LayerNormParamType<T>).
  // Attributes: "epsilon", "begin_norm_axis".
  //
  // When both Scale and Bias are given they must share one data type, and
  // that type must be either X's type (T) or U; otherwise InvalidArgument is
  // raised before any kernel is launched.
  void Compute(const framework::ExecutionContext &ctx) const override {
    using U = LayerNormParamType<T>;
    const float epsilon = ctx.Attr<float>("epsilon");
    auto *scale = ctx.Input<Tensor>("Scale");
    auto *bias = ctx.Input<Tensor>("Bias");
    auto *x = ctx.Input<Tensor>("X");

    auto *y = ctx.Output<Tensor>("Y");
    auto *mean = ctx.Output<Tensor>("Mean");
    auto *var = ctx.Output<Tensor>("Variance");
    const auto begin_norm_axis = ctx.Attr<int>("begin_norm_axis");

    const auto x_dims = x->dims();
    auto *x_data = x->data<T>();
    auto *y_data = y->mutable_data<T>(ctx.GetPlace());
    auto *mean_data = mean->mutable_data<U>(ctx.GetPlace());
    auto *var_data = var->mutable_data<U>(ctx.GetPlace());

    // Scale/Bias are optional. Keep untyped pointers here; they are cast to
    // the element type selected below when the kernel is launched.
    auto *void_scale_data = (scale == nullptr ? nullptr : scale->data<void>());
    auto *void_bias_data = (bias == nullptr ? nullptr : bias->data<void>());

    // Determine the common dtype of Scale/Bias. If neither is present, use
    // X's dtype so the "same dtype as X" path is taken.
    framework::proto::VarType::Type x_dtype = x->type();
    framework::proto::VarType::Type scale_bias_dtype;
    if (void_scale_data != nullptr) {
      scale_bias_dtype = scale->type();
      if (void_bias_data != nullptr) {
        PADDLE_ENFORCE_EQ(scale_bias_dtype, bias->type(),
                          platform::errors::InvalidArgument(
                              "The Scale and Bias of layer_norm op "
                              "should have the same data type."));
      }
    } else {
      scale_bias_dtype = (void_bias_data != nullptr ? bias->type() : x_dtype);
    }

    // Apart from matching X, the only accepted Scale/Bias dtype is U.
    bool is_scale_bias_same_dtype_with_x = x_dtype == scale_bias_dtype;
    if (!is_scale_bias_same_dtype_with_x) {
      PADDLE_ENFORCE_EQ(scale_bias_dtype,
                        framework::DataTypeTrait<U>::DataType(),
                        platform::errors::InvalidArgument(
                            "Unsupported data type of Scale and Bias: %s",
                            framework::DataTypeToString(scale_bias_dtype)));
    }

    // 2-D view of X split at begin_norm_axis: batch_size rows, each of
    // feature_size elements.
    auto matrix_dim = framework::flatten_to_2d(x_dims, begin_norm_axis);
    int64_t batch_size = static_cast<int64_t>(matrix_dim[0]);
    int64_t feature_size = static_cast<int64_t>(matrix_dim[1]);

    auto stream = ctx.cuda_device_context().stream();

// Launches LayerNormForward with batch_size blocks of kBlockDim threads on
// the op's stream, reading Scale/Bias as ScaleBiasT. Block widths not handled
// by FIXED_BLOCK_DIM_CASE raise InvalidArgument.
#define PADDLE_LAUNCH_LAYERNORM_FWD(ScaleBiasT, IsScaleBiasSameDTypeWithX) \
  do {                                                                     \
    switch (GetDesiredBlockDim(feature_size)) {                            \
      FIXED_BLOCK_DIM_CASE(                                                \
          LayerNormForward<T, U, kBlockDim, IsScaleBiasSameDTypeWithX><<<  \
              batch_size, kBlockDim, 0, stream>>>(                         \
              x_data, static_cast<const ScaleBiasT *>(void_scale_data),    \
              static_cast<const ScaleBiasT *>(void_bias_data), y_data,     \
              mean_data, var_data, epsilon, feature_size));                \
      default:                                                             \
        PADDLE_THROW(platform::errors::InvalidArgument(                    \
            "Product from begin_norm_axis to end must be larger than 1")); \
        break;                                                             \
    }                                                                      \
  } while (0)

    if (is_scale_bias_same_dtype_with_x) {
      PADDLE_LAUNCH_LAYERNORM_FWD(T, true);
    } else {
      PADDLE_LAUNCH_LAYERNORM_FWD(U, false);
    }

#undef PADDLE_LAUNCH_LAYERNORM_FWD
  }
};

template <typename T>
class LayerNormGradKernel<platform::CUDADeviceContext, T>
    : public framework::OpKernel<T> {
 public:
  // Backward of layer normalization on the GPU.
  //
  // Reads X, Mean, Variance (as U = LayerNormParamType<T>), optional Scale and
  // Bias, and Y@GRAD. It writes whichever of X@GRAD, Scale@GRAD and Bias@GRAD
  // are requested; each output may be absent. Scale@GRAD/Bias@GRAD are
  // allocated in the Scale/Bias data type. That type must be either X's type
  // (T) or U, as in the forward kernel; otherwise InvalidArgument is raised.
  void Compute(const framework::ExecutionContext &ctx) const override {
    using U = LayerNormParamType<T>;
    const float epsilon = ctx.Attr<float>("epsilon");
    // d_x, d_scale, d_bias may be nullptr
    auto *d_x = ctx.Output<Tensor>(framework::GradVarName("X"));
    auto *d_scale = ctx.Output<Tensor>(framework::GradVarName("Scale"));
    auto *d_bias = ctx.Output<Tensor>(framework::GradVarName("Bias"));

    auto *x = ctx.Input<Tensor>("X");
    auto *mean = ctx.Input<Tensor>("Mean");
    auto *var = ctx.Input<Tensor>("Variance");
    auto *scale = ctx.Input<Tensor>("Scale");
    auto *bias = ctx.Input<Tensor>("Bias");
    auto *d_y = ctx.Input<Tensor>(framework::GradVarName("Y"));

    // 2-D view of X split at begin_norm_axis: batch_size rows, each of
    // feature_size elements.
    const auto &x_dims = x->dims();
    const auto begin_norm_axis = ctx.Attr<int>("begin_norm_axis");
    auto matrix_dim = framework::flatten_to_2d(x_dims, begin_norm_axis);
    int64_t batch_size = static_cast<int64_t>(matrix_dim[0]);
    int64_t feature_size = static_cast<int64_t>(matrix_dim[1]);

    auto *x_data = x->data<T>();
    auto *d_y_data = d_y->data<T>();

    auto *mean_data = mean->data<U>();
    auto *var_data = var->data<U>();

    // Allocated exactly once here and reused by the launch macro below.
    auto *d_x_data =
        (d_x == nullptr ? nullptr : d_x->mutable_data<T>(ctx.GetPlace()));

    framework::proto::VarType::Type x_dtype = x->type();
    framework::proto::VarType::Type scale_bias_dtype;
    if (scale != nullptr) {
      scale_bias_dtype = scale->type();
    } else {
      // FIXME(zengjinle): do not find a better way to get the right
      // data type of the d_scale and d_bias if scale == nullptr.
      // NOTE(review): saved_type() rather than type() is presumably used
      // because Bias's buffer may not be held by the grad op — confirm.
      if (bias != nullptr) {
        scale_bias_dtype = bias->saved_type();
      } else {
        scale_bias_dtype = x_dtype;
      }
    }

    // Same rule as the forward kernel: a Scale/Bias dtype that differs from
    // X must be U. Without this check, d_scale/d_bias would silently be
    // allocated as U even when Scale/Bias hold some other type.
    const bool is_scale_bias_same_dtype_with_x = scale_bias_dtype == x_dtype;
    if (!is_scale_bias_same_dtype_with_x) {
      PADDLE_ENFORCE_EQ(scale_bias_dtype,
                        framework::DataTypeTrait<U>::DataType(),
                        platform::errors::InvalidArgument(
                            "Unsupported data type of Scale and Bias: %s",
                            framework::DataTypeToString(scale_bias_dtype)));
    }

// Reads Scale and allocates d_scale/d_bias as ScaleBiasT, then runs
// LayerNormBackward on the op's CUDA device context.
#define PADDLE_LAUNCH_LAYERNORM_BWD(ScaleBiasT, IsScaleBiasSameDTypeWithX) \
  do {                                                                     \
    auto *scale_data =                                                     \
        (scale == nullptr ? nullptr : scale->data<ScaleBiasT>());          \
    auto *d_scale_data =                                                   \
        (d_scale == nullptr ? nullptr : d_scale->mutable_data<ScaleBiasT>( \
                                            ctx.GetPlace()));              \
    auto *d_bias_data =                                                    \
        (d_bias == nullptr ? nullptr : d_bias->mutable_data<ScaleBiasT>(   \
                                           ctx.GetPlace()));               \
    LayerNormBackward<T, U, IsScaleBiasSameDTypeWithX>(                    \
        x_data, d_y_data, scale_data, mean_data, var_data, d_x_data,       \
        d_scale_data, d_bias_data, epsilon, batch_size, feature_size,      \
        ctx.cuda_device_context());                                        \
  } while (0)

    if (is_scale_bias_same_dtype_with_x) {
      PADDLE_LAUNCH_LAYERNORM_BWD(T, true);
    } else {
      PADDLE_LAUNCH_LAYERNORM_BWD(U, false);
    }

#undef PADDLE_LAUNCH_LAYERNORM_BWD
  }
};
F
furnace 已提交
200

P
Pei Yang 已提交
201
// Explicit instantiation: the out-of-line operator() of
// LayerNormDirectCUDAFunctor is defined in this .cu file, so code for it is
// emitted here. Only the float specialization is instantiated; other element
// types would need their own explicit instantiation.
template class LayerNormDirectCUDAFunctor<float>;
S
sneaxiy 已提交
202 203 204
}  // namespace operators
}  // namespace paddle

C
chengduoZH 已提交
205
namespace ops = paddle::operators;
namespace plat = paddle::platform;

// GPU kernel registrations for layer_norm and layer_norm_grad.
#ifdef PADDLE_WITH_HIP
// MIOpen does not support double, so the HIP build registers only the
// float and float16 kernels.
REGISTER_OP_CUDA_KERNEL(
    layer_norm, ops::LayerNormKernel<plat::CUDADeviceContext, float>,
    ops::LayerNormKernel<plat::CUDADeviceContext, plat::float16>);
REGISTER_OP_CUDA_KERNEL(
    layer_norm_grad, ops::LayerNormGradKernel<plat::CUDADeviceContext, float>,
    ops::LayerNormGradKernel<plat::CUDADeviceContext, plat::float16>);
#else
// Non-HIP builds additionally register the double kernels.
REGISTER_OP_CUDA_KERNEL(
    layer_norm, ops::LayerNormKernel<plat::CUDADeviceContext, float>,
    ops::LayerNormKernel<plat::CUDADeviceContext, double>,
    ops::LayerNormKernel<plat::CUDADeviceContext, plat::float16>);
REGISTER_OP_CUDA_KERNEL(
    layer_norm_grad, ops::LayerNormGradKernel<plat::CUDADeviceContext, float>,
    ops::LayerNormGradKernel<plat::CUDADeviceContext, double>,
    ops::LayerNormGradKernel<plat::CUDADeviceContext, plat::float16>);
#endif