/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/layer_norm_op.h"
#include "paddle/fluid/operators/npu_op_runner.h"

namespace paddle {
namespace operators {

using Tensor = framework::Tensor;
using DDim = framework::DDim;

using DataLayout = framework::DataLayout;

// Maps a kernel element type T to the element type used by LayerNorm's
// parameter-like tensors (Scale/Bias and the Mean/Variance statistics).
// Only float16 and float are specialized; any other T is an incomplete type.
template <typename T>
class NormDataType;

template <>
class NormDataType<platform::float16> {
 public:
  // The scaling param type is float for HALF and FLOAT tensors
  using ScalingParamType = const float;
  using BatchNormParamType = float;
};

template <>
class NormDataType<float> {
 public:
  using ScalingParamType = const float;
  using BatchNormParamType = float;
};

// Element type of LayerNorm params/statistics for kernel type T: float for
// both float and float16 kernels.
// NOTE: a former `template <typename T> using NormDataType = NormDataType<T>;`
// was removed. An alias template may not reuse the name of a class template
// declared in the same scope, and the alias added nothing.
template <typename T>
using LayerNormParamType = typename NormDataType<T>::BatchNormParamType;

template <typename T>
class LayerNormNPUKernel : public framework::OpKernel<T> {
 public:
  // Computes Y = LayerNorm(X) plus the Mean and Variance outputs with Ascend's
  // LayerNorm op, normalizing over dims [begin_norm_axis, rank(X)).
  //
  // Inputs : X, Scale (optional), Bias (optional).
  // Outputs: Y, Mean, Variance.
  // Attrs  : begin_norm_axis (int), epsilon (float).
  //
  //  * A missing Scale / Bias is replaced by a tensor of X's dtype filled with
  //    1 / 0 by the FillD op.
  //  * Ascend requires Scale and Bias shaped like X.shape[begin_norm_axis:].
  //    They are reshaped through local views that share the input buffers, so
  //    the const input tensors are never modified. Previously they were
  //    resized in place via const_cast and only restored on the success path.
  //  * When X is FP16 and Scale or Bias is FP32:
  //      - the FP32 params are cast to FP16 for the device op;
  //      - Mean and Variance are written into FP16 temporaries;
  //      - the temporaries are cast back into FP32 (U) Mean/Variance outputs.
  void Compute(const framework::ExecutionContext& ctx) const override {
    using U = LayerNormParamType<T>;
    const auto begin_norm_axis = ctx.Attr<int>("begin_norm_axis");
    const auto epsilon = ctx.Attr<float>("epsilon");
    const auto* x = ctx.Input<Tensor>("X");
    const auto* scale = ctx.Input<Tensor>("Scale");
    const auto* bias = ctx.Input<Tensor>("Bias");
    auto* y = ctx.Output<Tensor>("Y");
    auto* mean = ctx.Output<Tensor>("Mean");
    auto* variance = ctx.Output<Tensor>("Variance");
    const auto& x_dims = x->dims();

    // The shape of scale and bias should be equal to x.shape[begin_norm_axis:],
    // required by Ascend.
    std::vector<int> axes;
    for (auto i = begin_norm_axis; i < x_dims.size(); ++i) {
      axes.push_back(x_dims[i]);
    }
    const auto param_dims = framework::make_ddim(axes);

    auto place = ctx.GetPlace();
    auto stream =
        ctx.template device_context<paddle::platform::NPUDeviceContext>()
            .stream();

    // scale_view holds one of two things: a 1-filled default when Scale is
    // absent, or a view of Scale (sharing its buffer) with shape param_dims.
    Tensor scale_view(x->type());
    if (!scale) {
      scale_view.mutable_data<T>(param_dims, place);
      Tensor value(x->type());
      value.mutable_data<T>({1}, place);
      TensorFromVector(std::vector<T>{static_cast<T>(1.0)},
                       ctx.device_context(), &value);
      auto runner =
          NpuOpRunner("FillD", {value}, {scale_view}, {{"dims", axes}});
      runner.Run(stream);
    } else {
      scale_view.ShareDataWith(*scale);
      scale_view.Resize(param_dims);
    }

    // bias_view holds one of two things: a 0-filled default when Bias is
    // absent, or a view of Bias (sharing its buffer) with shape param_dims.
    Tensor bias_view(x->type());
    if (!bias) {
      bias_view.mutable_data<T>(param_dims, place);
      Tensor value(x->type());
      value.mutable_data<T>({1}, place);
      TensorFromVector(std::vector<T>{static_cast<T>(0)}, ctx.device_context(),
                       &value);
      auto runner =
          NpuOpRunner("FillD", {value}, {bias_view}, {{"dims", axes}});
      runner.Run(stream);
    } else {
      bias_view.ShareDataWith(*bias);
      bias_view.Resize(param_dims);
    }

    const bool x_is_fp16 = x->type() == framework::proto::VarType::FP16;
    const bool scale_is_fp32 =
        scale_view.type() == framework::proto::VarType::FP32;
    const bool bias_is_fp32 =
        bias_view.type() == framework::proto::VarType::FP32;

    // cast scale from LayerNormParamType to T if needed
    Tensor cast_scale(x->type());
    if (x_is_fp16 && scale_is_fp32) {
      cast_scale.Resize(scale_view.dims());
      cast_scale.mutable_data<T>(place);
      auto dst_dtype = ConvertToNpuDtype(x->type());
      auto runner_cast_scale =
          NpuOpRunner("Cast", {scale_view}, {cast_scale},
                      {{"dst_type", static_cast<int>(dst_dtype)}});
      runner_cast_scale.Run(stream);
    } else {
      cast_scale.ShareDataWith(scale_view);
    }

    // cast bias from LayerNormParamType to T if needed
    Tensor cast_bias(x->type());
    if (x_is_fp16 && bias_is_fp32) {
      cast_bias.Resize(bias_view.dims());
      cast_bias.mutable_data<T>(place);
      auto dst_dtype = ConvertToNpuDtype(x->type());
      auto runner_cast_bias =
          NpuOpRunner("Cast", {bias_view}, {cast_bias},
                      {{"dst_type", static_cast<int>(dst_dtype)}});
      runner_cast_bias.Run(stream);
    } else {
      cast_bias.ShareDataWith(bias_view);
    }

    y->mutable_data<T>(place);

    // When FP32 params were cast down, Mean and Variance are U-typed outputs.
    // In that case the device op writes into FP16 temporaries instead, and
    // those are cast back after the op runs.
    const bool stats_need_cast = x_is_fp16 && (scale_is_fp32 || bias_is_fp32);
    Tensor* tmp_mean = mean;
    Tensor cast_mean(x->type());
    Tensor* tmp_variance = variance;
    Tensor cast_variance(x->type());
    if (stats_need_cast) {
      cast_mean.Resize(mean->dims());
      cast_mean.mutable_data<T>(place);
      tmp_mean = &cast_mean;
      mean->mutable_data<U>(place);

      cast_variance.Resize(variance->dims());
      cast_variance.mutable_data<T>(place);
      tmp_variance = &cast_variance;
      variance->mutable_data<U>(place);
    } else {
      mean->mutable_data<T>(place);
      variance->mutable_data<T>(place);
    }

    auto runner = NpuOpRunner("LayerNorm", {*x, cast_scale, cast_bias},
                              {*y, *tmp_mean, *tmp_variance},
                              {{"begin_norm_axis", begin_norm_axis},
                               {"begin_params_axis", begin_norm_axis},
                               {"epsilon", epsilon}});
    runner.Run(stream);

    // Cast the FP16 statistics back into the FP32 outputs, when temporaries
    // were used above.
    if (tmp_mean != mean) {
      auto dst_dtype = ConvertToNpuDtype(mean->type());
      auto runner_cast_mean =
          NpuOpRunner("Cast", {*tmp_mean}, {*mean},
                      {{"dst_type", static_cast<int>(dst_dtype)}});
      runner_cast_mean.Run(stream);
    }
    if (tmp_variance != variance) {
      auto dst_dtype = ConvertToNpuDtype(variance->type());
      auto runner_cast_variance =
          NpuOpRunner("Cast", {*tmp_variance}, {*variance},
                      {{"dst_type", static_cast<int>(dst_dtype)}});
      runner_cast_variance.Run(stream);
    }
  }
};

template <typename T>
class LayerNormGradNPUKernel : public framework::OpKernel<T> {
 public:
  // Computes X@GRAD, Scale@GRAD and Bias@GRAD of LayerNorm with Ascend's
  // LayerNormGrad op.
  //
  // Inputs : X, Mean, Variance, Scale (optional), Y@GRAD.
  // Outputs: X@GRAD, Scale@GRAD, Bias@GRAD. Each is optional; an output that
  //          is not requested is computed into a local scratch tensor.
  // Attrs  : begin_norm_axis (int).
  //
  //  * Ascend has two shape requirements:
  //      - Mean/Variance must have X's rank, with the normalized dims set to 1;
  //      - Scale must be shaped like X.shape[begin_norm_axis:].
  //    These inputs are reshaped through local views that share their
  //    buffers, so the const input tensors are never modified. Previously
  //    they were resized in place via const_cast and only restored on the
  //    success path.
  //  * A missing Scale is replaced by a 1-filled tensor of X's dtype.
  //  * When X is FP16, FP32 Scale/Mean/Variance are cast to FP16 for the
  //    device op.
  //  * If Mean or Variance is FP32, Scale@GRAD and Bias@GRAD are written into
  //    FP16 temporaries, which are then cast back into FP32 (U) outputs.
  //  * Scale@GRAD and Bias@GRAD end up with the 1-D shape [right], where
  //    right is the product of X.shape[begin_norm_axis:].
  void Compute(const framework::ExecutionContext& ctx) const override {
    using U = LayerNormParamType<T>;
    const auto begin_norm_axis = ctx.Attr<int>("begin_norm_axis");
    const auto* x = ctx.Input<Tensor>("X");
    const auto& x_dims = x->dims();
    const auto* mean = ctx.Input<Tensor>("Mean");
    const auto* variance = ctx.Input<Tensor>("Variance");
    const auto* scale = ctx.Input<Tensor>("Scale");
    const auto* dy = ctx.Input<Tensor>(framework::GradVarName("Y"));
    auto* dx = ctx.Output<Tensor>(framework::GradVarName("X"));
    auto* dscale = ctx.Output<Tensor>(framework::GradVarName("Scale"));
    auto* dbias = ctx.Output<Tensor>(framework::GradVarName("Bias"));

    // No need to compute any gradient, just return.
    if (!dx && !dscale && !dbias) {
      return;
    }

    auto matrix_dim = framework::flatten_to_2d(x_dims, begin_norm_axis);
    int right = static_cast<int>(matrix_dim[1]);

    // Scale (and the Scale/Bias grads) must be shaped like
    // x.shape[begin_norm_axis:], required by Ascend.
    std::vector<int> axes;
    for (auto i = begin_norm_axis; i < x_dims.size(); ++i) {
      axes.push_back(x_dims[i]);
    }
    const auto param_dims = framework::make_ddim(axes);

    auto place = ctx.GetPlace();
    auto stream =
        ctx.template device_context<paddle::platform::NPUDeviceContext>()
            .stream();

    // The rank of mean should be equal to x, required by Ascend.
    std::vector<int> new_shape;
    for (auto i = 0; i < begin_norm_axis; ++i) {
      new_shape.push_back(x_dims[i]);
    }
    for (auto i = begin_norm_axis; i < x_dims.size(); ++i) {
      new_shape.push_back(1);
    }
    const auto stat_dims = framework::make_ddim(new_shape);

    // Rank-adjusted views of Mean/Variance; the inputs keep their own dims.
    Tensor mean_view(mean->type());
    mean_view.ShareDataWith(*mean);
    mean_view.Resize(stat_dims);
    Tensor variance_view(variance->type());
    variance_view.ShareDataWith(*variance);
    variance_view.Resize(stat_dims);

    // scale_view holds one of two things: a 1-filled default when Scale is
    // absent, or a view of Scale (sharing its buffer) with shape param_dims.
    Tensor scale_view(x->type());
    if (!scale) {
      scale_view.mutable_data<T>(param_dims, place);
      Tensor value(x->type());
      value.mutable_data<T>({1}, place);
      TensorFromVector(std::vector<T>{static_cast<T>(1.0)},
                       ctx.device_context(), &value);
      auto runner =
          NpuOpRunner("FillD", {value}, {scale_view}, {{"dims", axes}});
      runner.Run(stream);
    } else {
      scale_view.ShareDataWith(*scale);
      scale_view.Resize(param_dims);
    }

    const bool x_is_fp16 = x->type() == framework::proto::VarType::FP16;
    const bool mean_is_fp32 =
        mean_view.type() == framework::proto::VarType::FP32;
    const bool variance_is_fp32 =
        variance_view.type() == framework::proto::VarType::FP32;

    // cast scale from LayerNormParamType to T if needed
    Tensor cast_scale(x->type());
    if (x_is_fp16 && scale_view.type() == framework::proto::VarType::FP32) {
      cast_scale.Resize(scale_view.dims());
      cast_scale.mutable_data<T>(place);
      auto dst_dtype = ConvertToNpuDtype(x->type());
      auto runner_cast_scale =
          NpuOpRunner("Cast", {scale_view}, {cast_scale},
                      {{"dst_type", static_cast<int>(dst_dtype)}});
      runner_cast_scale.Run(stream);
    } else {
      cast_scale.ShareDataWith(scale_view);
    }

    // cast mean from LayerNormParamType to T if needed
    Tensor cast_mean(x->type());
    if (x_is_fp16 && mean_is_fp32) {
      cast_mean.Resize(mean_view.dims());
      cast_mean.mutable_data<T>(place);
      auto dst_dtype = ConvertToNpuDtype(x->type());
      auto runner_cast_mean =
          NpuOpRunner("Cast", {mean_view}, {cast_mean},
                      {{"dst_type", static_cast<int>(dst_dtype)}});
      runner_cast_mean.Run(stream);
    } else {
      cast_mean.ShareDataWith(mean_view);
    }

    // cast variance from LayerNormParamType to T if needed
    Tensor cast_variance(x->type());
    if (x_is_fp16 && variance_is_fp32) {
      cast_variance.Resize(variance_view.dims());
      cast_variance.mutable_data<T>(place);
      auto dst_dtype = ConvertToNpuDtype(x->type());
      auto runner_cast_variance =
          NpuOpRunner("Cast", {variance_view}, {cast_variance},
                      {{"dst_type", static_cast<int>(dst_dtype)}});
      runner_cast_variance.Run(stream);
    } else {
      cast_variance.ShareDataWith(variance_view);
    }

    // Unrequested gradients still need an output buffer for the device op.
    Tensor dx_(dy->type()), dscale_(dy->type()), dbias_(dy->type());
    dx = (dx == nullptr) ? &dx_ : dx;
    dscale = (dscale == nullptr) ? &dscale_ : dscale;
    dbias = (dbias == nullptr) ? &dbias_ : dbias;

    dx->Resize(x->dims());
    dx->mutable_data<T>(place);

    dscale->Resize(param_dims);
    dbias->Resize(param_dims);

    // Scale/Bias grads are U-typed when the statistics were FP32. In that
    // case the device op writes into FP16 temporaries, which are cast back
    // after the op runs.
    const bool grads_need_cast = x_is_fp16 && (mean_is_fp32 || variance_is_fp32);
    Tensor* tmp_dscale = dscale;
    Tensor cast_dscale(x->type());
    Tensor* tmp_dbias = dbias;
    Tensor cast_dbias(x->type());
    if (grads_need_cast) {
      cast_dscale.Resize(dscale->dims());
      cast_dscale.mutable_data<T>(place);
      tmp_dscale = &cast_dscale;
      dscale->mutable_data<U>(place);

      cast_dbias.Resize(dbias->dims());
      cast_dbias.mutable_data<T>(place);
      tmp_dbias = &cast_dbias;
      dbias->mutable_data<U>(place);
    } else {
      dscale->mutable_data<T>(place);
      dbias->mutable_data<T>(place);
    }

    auto runner = NpuOpRunner("LayerNormGrad",
                              {*dy, *x, cast_variance, cast_mean, cast_scale},
                              {*dx, *tmp_dscale, *tmp_dbias}, {});
    runner.Run(stream);

    // Cast the FP16 grads back into the FP32 outputs, when temporaries were
    // used above.
    if (tmp_dscale != dscale) {
      auto dst_dtype = ConvertToNpuDtype(dscale->type());
      auto runner_cast_dscale =
          NpuOpRunner("Cast", {*tmp_dscale}, {*dscale},
                      {{"dst_type", static_cast<int>(dst_dtype)}});
      runner_cast_dscale.Run(stream);
    }
    if (tmp_dbias != dbias) {
      auto dst_dtype = ConvertToNpuDtype(dbias->type());
      auto runner_cast_dbias =
          NpuOpRunner("Cast", {*tmp_dbias}, {*dbias},
                      {{"dst_type", static_cast<int>(dst_dtype)}});
      runner_cast_dbias.Run(stream);
    }

    // Report the Scale/Bias grads with the 1-D shape [right].
    dscale->Resize(framework::make_ddim({right}));
    dbias->Resize(framework::make_ddim({right}));
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
namespace plat = paddle::platform;

// Forward LayerNorm on NPU: one kernel instantiation per supported element
// type (float and float16).
REGISTER_OP_NPU_KERNEL(
    layer_norm,
    ops::LayerNormNPUKernel<float>,
    ops::LayerNormNPUKernel<plat::float16>);

// Backward LayerNorm on NPU, registered for the same element types as the
// forward kernel.
REGISTER_OP_NPU_KERNEL(
    layer_norm_grad,
    ops::LayerNormGradNPUKernel<float>,
    ops::LayerNormGradNPUKernel<plat::float16>);