/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/layer_norm_op.h"
#include "paddle/fluid/platform/device/npu/npu_op_runner.h"

namespace paddle {
namespace operators {

using Tensor = framework::Tensor;
using DDim = framework::DDim;
using DataLayout = framework::DataLayout;

template <typename T>
class NormDataType;

template <>
class NormDataType<platform::float16> {
 public:
  // The scaling param type is float for HALF and FLOAT tensors
  using ScalingParamType = const float;
  using BatchNormParamType = float;
};

template <>
class NormDataType<float> {
 public:
  using ScalingParamType = const float;
  using BatchNormParamType = float;
};

template <typename T>
using LayerNormParamType = typename NormDataType<T>::BatchNormParamType;

template <typename T>
class LayerNormNPUKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    using U = LayerNormParamType<T>;
    const auto begin_norm_axis = ctx.Attr<int>("begin_norm_axis");
    const auto epsilon = ctx.Attr<float>("epsilon");
    const auto* x = ctx.Input<Tensor>("X");
    const auto* scale = ctx.Input<Tensor>("Scale");
    const auto* bias = ctx.Input<Tensor>("Bias");
    auto* y = ctx.Output<Tensor>("Y");
    auto* mean = ctx.Output<Tensor>("Mean");
    auto* variance = ctx.Output<Tensor>("Variance");
    const auto& x_dims = x->dims();
    std::vector<int> axes;
    auto matrix_dim = framework::flatten_to_2d(x_dims, begin_norm_axis);
    int right = static_cast<int>(matrix_dim[1]);

    // The shape of scale and bias should be equal to
    // x.shape[begin_norm_axis:], required by Ascend.
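    // Illustrative example: for x.shape = [2, 3, 4, 5] and begin_norm_axis = 2,
    // flatten_to_2d gives matrix_dim = [6, 20], so right = 20 and
    // axes = {4, 5}; Scale/Bias are temporarily viewed as [4, 5] here and
    // restored to [right] at the end of Compute.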
    for (auto i = begin_norm_axis; i < x_dims.size(); ++i) {
      axes.push_back(x_dims[i]);
    }

    auto place = ctx.GetPlace();
    auto stream =
        ctx.template device_context<paddle::platform::NPUDeviceContext>()
            .stream();

    Tensor default_scale(x->type());
    if (!scale) {
      default_scale.mutable_data<T>(framework::make_ddim(axes), place);
      Tensor value(x->type());
      value.mutable_data<T>({1}, place);
      FillNpuTensorWithConstant<T>(&value, static_cast<T>(1.0));
      const auto& runner =
          NpuOpRunner("FillD", {value}, {default_scale}, {{"dims", axes}});
      runner.Run(stream);
      scale = &default_scale;
    } else {
      const_cast<Tensor*>(scale)->Resize(framework::make_ddim(axes));
    }

    Tensor default_bias(x->type());
    if (!bias) {
      default_bias.mutable_data<T>(framework::make_ddim(axes), place);
      Tensor value(x->type());
      value.mutable_data<T>({1}, place);
      FillNpuTensorWithConstant<T>(&value, static_cast<T>(0));
      const auto& runner =
          NpuOpRunner("FillD", {value}, {default_bias}, {{"dims", axes}});
      runner.Run(stream);
      bias = &default_bias;
    } else {
      const_cast<Tensor*>(bias)->Resize(framework::make_ddim(axes));
    }

    // cast scale from LayerNormParamType to T if needed
    Tensor cast_scale(x->type());
    if (x->type() == framework::proto::VarType::FP16 &&
        scale->type() == framework::proto::VarType::FP32) {
      cast_scale.Resize(scale->dims());
      cast_scale.mutable_data<T>(ctx.GetPlace());
      auto dst_dtype = ConvertToNpuDtype(x->type());
      const auto& runner_cast_scale =
          NpuOpRunner("Cast", {*scale}, {cast_scale},
                      {{"dst_type", static_cast<int>(dst_dtype)}});
      runner_cast_scale.Run(stream);
    } else {
      cast_scale.ShareDataWith(*scale);
    }

    // cast bias from LayerNormParamType to T if needed
    Tensor cast_bias(x->type());
    if (x->type() == framework::proto::VarType::FP16 &&
        bias->type() == framework::proto::VarType::FP32) {
      cast_bias.Resize(bias->dims());
      cast_bias.mutable_data<T>(ctx.GetPlace());
      auto dst_dtype = ConvertToNpuDtype(x->type());
      const auto& runner_cast_bias =
          NpuOpRunner("Cast", {*bias}, {cast_bias},
                      {{"dst_type", static_cast<int>(dst_dtype)}});
      runner_cast_bias.Run(stream);
    } else {
      cast_bias.ShareDataWith(*bias);
    }

    y->mutable_data<T>(ctx.GetPlace());
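    // Note on mixed precision: when X is FP16 while Scale/Bias are FP32 (the
    // usual AMP layout), the Ascend LayerNorm op is fed FP16 tensors, so Mean
    // and Variance are first written to FP16 temporaries (tmp_mean /
    // tmp_variance below) and cast back to the FP32 outputs after the run.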
    // mean should be of U type
    Tensor* tmp_mean = mean;
    Tensor cast_mean(x->type());
    if (x->type() == framework::proto::VarType::FP16 &&
        (scale->type() == framework::proto::VarType::FP32 ||
         bias->type() == framework::proto::VarType::FP32)) {
      cast_mean.Resize(mean->dims());
      cast_mean.mutable_data<T>(ctx.GetPlace());
      tmp_mean = &cast_mean;
      mean->mutable_data<U>(ctx.GetPlace());
    } else {
      mean->mutable_data<T>(ctx.GetPlace());
    }

    // same for variance
    Tensor* tmp_variance = variance;
    Tensor cast_variance(x->type());
    if (x->type() == framework::proto::VarType::FP16 &&
        (scale->type() == framework::proto::VarType::FP32 ||
         bias->type() == framework::proto::VarType::FP32)) {
      cast_variance.Resize(variance->dims());
      cast_variance.mutable_data<T>(ctx.GetPlace());
      tmp_variance = &cast_variance;
      variance->mutable_data<U>(ctx.GetPlace());
    } else {
      variance->mutable_data<T>(ctx.GetPlace());
    }

    const auto& runner =
        NpuOpRunner("LayerNorm", {*x, cast_scale, cast_bias},
                    {*y, *tmp_mean, *tmp_variance},
                    {{"begin_norm_axis", begin_norm_axis},
                     {"begin_params_axis", begin_norm_axis},
                     {"epsilon", epsilon}});
    runner.Run(stream);

    // cast back from FP16 to FP32
    if (x->type() == framework::proto::VarType::FP16 &&
        mean->type() == framework::proto::VarType::FP32) {
      auto dst_dtype = ConvertToNpuDtype(mean->type());
      const auto& runner_cast_mean =
          NpuOpRunner("Cast", {*tmp_mean}, {*mean},
                      {{"dst_type", static_cast<int>(dst_dtype)}});
      runner_cast_mean.Run(stream);
    }
    // same for variance
    if (x->type() == framework::proto::VarType::FP16 &&
        variance->type() == framework::proto::VarType::FP32) {
      auto dst_dtype = ConvertToNpuDtype(variance->type());
      const auto& runner_cast_variance =
          NpuOpRunner("Cast", {*tmp_variance}, {*variance},
                      {{"dst_type", static_cast<int>(dst_dtype)}});
      runner_cast_variance.Run(stream);
    }

    // revert shape of scale and bias
    // TODO(zhiqiu): better implementation, use tmp tensor to avoid writing the
    // input tensor.
    const_cast<Tensor*>(scale)->Resize(framework::make_ddim({right}));
    const_cast<Tensor*>(bias)->Resize(framework::make_ddim({right}));
  }
};

template <typename T>
class LayerNormGradNPUKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    using U = LayerNormParamType<T>;
    const auto begin_norm_axis = ctx.Attr<int>("begin_norm_axis");
    const auto* x = ctx.Input<Tensor>("X");
    const auto& x_dims = x->dims();
    const auto* mean = ctx.Input<Tensor>("Mean");
    const auto* variance = ctx.Input<Tensor>("Variance");
    const auto* scale = ctx.Input<Tensor>("Scale");
    const auto* dy = ctx.Input<Tensor>(framework::GradVarName("Y"));
    auto* dx = ctx.Output<Tensor>(framework::GradVarName("X"));
    auto* dscale = ctx.Output<Tensor>(framework::GradVarName("Scale"));
    auto* dbias = ctx.Output<Tensor>(framework::GradVarName("Bias"));

    auto matrix_dim = framework::flatten_to_2d(x_dims, begin_norm_axis);
    int right = static_cast<int>(matrix_dim[1]);

    std::vector<int> axes;
    for (auto i = begin_norm_axis; i < x_dims.size(); ++i) {
      axes.push_back(x_dims[i]);
    }

    auto place = ctx.GetPlace();
    auto stream =
        ctx.template device_context<paddle::platform::NPUDeviceContext>()
            .stream();

    // No gradient to compute, just return.
    if (!dx && !dscale && !dbias) {
      return;
    }

    // The rank of mean should be equal to x, required by Ascend.
    std::vector<int> new_shape;
    for (auto i = 0; i < begin_norm_axis; ++i) {
      new_shape.push_back(x_dims[i]);
    }
    for (auto i = begin_norm_axis; i < x_dims.size(); ++i) {
      new_shape.push_back(1);
    }

    auto mean_dims = mean->dims();
    const_cast<Tensor*>(mean)->Resize(framework::make_ddim(new_shape));
    const_cast<Tensor*>(variance)->Resize(framework::make_ddim(new_shape));
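    // Illustrative example: for x.shape = [2, 3, 4, 5] and begin_norm_axis = 2,
    // Mean/Variance hold 2 * 3 = 6 elements and are temporarily viewed as
    // [2, 3, 1, 1] so their rank matches x; the original dims are restored
    // from mean_dims at the end of Compute.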
    Tensor default_scale(x->type());
    if (!scale) {
      default_scale.mutable_data<T>(framework::make_ddim(axes), place);
      Tensor value(x->type());
      value.mutable_data<T>({1}, place);
      FillNpuTensorWithConstant<T>(&value, static_cast<T>(1.0));
      const auto& runner =
          NpuOpRunner("FillD", {value}, {default_scale}, {{"dims", axes}});
      runner.Run(stream);
      scale = &default_scale;
    } else {
      const_cast<Tensor*>(scale)->Resize(framework::make_ddim(axes));
    }

    // cast scale from LayerNormParamType to T if needed
    Tensor cast_scale(x->type());
    if (x->type() == framework::proto::VarType::FP16 &&
        scale->type() == framework::proto::VarType::FP32) {
      cast_scale.Resize(scale->dims());
      cast_scale.mutable_data<T>(ctx.GetPlace());
      auto dst_dtype = ConvertToNpuDtype(x->type());
      const auto& runner_cast_scale =
          NpuOpRunner("Cast", {*scale}, {cast_scale},
                      {{"dst_type", static_cast<int>(dst_dtype)}});
      runner_cast_scale.Run(stream);
    } else {
      cast_scale.ShareDataWith(*scale);
    }

    // cast mean from LayerNormParamType to T if needed
    Tensor cast_mean(x->type());
    if (x->type() == framework::proto::VarType::FP16 &&
        mean->type() == framework::proto::VarType::FP32) {
      cast_mean.Resize(mean->dims());
      cast_mean.mutable_data<T>(ctx.GetPlace());
      auto dst_dtype = ConvertToNpuDtype(x->type());
      const auto& runner_cast_mean =
          NpuOpRunner("Cast", {*mean}, {cast_mean},
                      {{"dst_type", static_cast<int>(dst_dtype)}});
      runner_cast_mean.Run(stream);
    } else {
      cast_mean.ShareDataWith(*mean);
    }

    // cast variance from LayerNormParamType to T if needed
    Tensor cast_variance(x->type());
    if (x->type() == framework::proto::VarType::FP16 &&
        variance->type() == framework::proto::VarType::FP32) {
      cast_variance.Resize(variance->dims());
      cast_variance.mutable_data<T>(ctx.GetPlace());
      auto dst_dtype = ConvertToNpuDtype(x->type());
      const auto& runner_cast_variance =
          NpuOpRunner("Cast", {*variance}, {cast_variance},
                      {{"dst_type", static_cast<int>(dst_dtype)}});
      runner_cast_variance.Run(stream);
    } else {
      cast_variance.ShareDataWith(*variance);
    }

    Tensor dx_(dy->type()), dscale_(dy->type()), dbias_(dy->type());
    dx = (dx == nullptr) ? &dx_ : dx;
    dscale = (dscale == nullptr) ? &dscale_ : dscale;
    dbias = (dbias == nullptr) ? &dbias_ : dbias;

    dx->Resize(x->dims());
    dx->mutable_data<T>(ctx.GetPlace());

    dscale->Resize(framework::make_ddim(axes));
    dbias->Resize(framework::make_ddim(axes));

    // dscale should be of U type
    Tensor* tmp_dscale = dscale;
    Tensor cast_dscale(x->type());
    if (x->type() == framework::proto::VarType::FP16 &&
        (mean->type() == framework::proto::VarType::FP32 ||
         variance->type() == framework::proto::VarType::FP32)) {
      cast_dscale.Resize(dscale->dims());
      cast_dscale.mutable_data<T>(ctx.GetPlace());
      tmp_dscale = &cast_dscale;
      dscale->mutable_data<U>(ctx.GetPlace());
    } else {
      dscale->mutable_data<T>(ctx.GetPlace());
    }

    // same for dbias
    Tensor* tmp_dbias = dbias;
    Tensor cast_dbias(x->type());
    if (x->type() == framework::proto::VarType::FP16 &&
        (mean->type() == framework::proto::VarType::FP32 ||
         variance->type() == framework::proto::VarType::FP32)) {
      cast_dbias.Resize(dbias->dims());
      cast_dbias.mutable_data<T>(ctx.GetPlace());
      tmp_dbias = &cast_dbias;
      dbias->mutable_data<U>(ctx.GetPlace());
    } else {
      dbias->mutable_data<T>(ctx.GetPlace());
    }

    const auto& runner = NpuOpRunner(
        "LayerNormGrad", {*dy, *x, cast_variance, cast_mean, cast_scale},
        {*dx, *tmp_dscale, *tmp_dbias}, {});
    runner.Run(stream);

    // cast back from FP16 to FP32
    if (x->type() == framework::proto::VarType::FP16 &&
        dscale->type() == framework::proto::VarType::FP32) {
      auto dst_dtype = ConvertToNpuDtype(dscale->type());
      const auto& runner_cast_dscale =
          NpuOpRunner("Cast", {*tmp_dscale}, {*dscale},
                      {{"dst_type", static_cast<int>(dst_dtype)}});
      runner_cast_dscale.Run(stream);
    }
    // same for dbias
    if (x->type() == framework::proto::VarType::FP16 &&
        dbias->type() == framework::proto::VarType::FP32) {
      auto dst_dtype = ConvertToNpuDtype(dbias->type());
      const auto& runner_cast_dbias =
          NpuOpRunner("Cast", {*tmp_dbias}, {*dbias},
                      {{"dst_type", static_cast<int>(dst_dtype)}});
      runner_cast_dbias.Run(stream);
    }

    // revert the shapes modified above
    const_cast<Tensor*>(mean)->Resize(mean_dims);
    const_cast<Tensor*>(variance)->Resize(mean_dims);
    const_cast<Tensor*>(scale)->Resize(framework::make_ddim({right}));
    dscale->Resize(framework::make_ddim({right}));
    dbias->Resize(framework::make_ddim({right}));
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
namespace plat = paddle::platform;

REGISTER_OP_NPU_KERNEL(layer_norm, ops::LayerNormNPUKernel<float>,
                       ops::LayerNormNPUKernel<plat::float16>);
REGISTER_OP_NPU_KERNEL(layer_norm_grad, ops::LayerNormGradNPUKernel<float>,
                       ops::LayerNormGradNPUKernel<plat::float16>);