diff --git a/paddle/fluid/operators/layer_norm_op_mlu.cc b/paddle/fluid/operators/layer_norm_op_mlu.cc
new file mode 100644
index 0000000000000000000000000000000000000000..a368af86a3da6cf90e2d12b975fe156e45d98fe6
--- /dev/null
+++ b/paddle/fluid/operators/layer_norm_op_mlu.cc
@@ -0,0 +1,234 @@
+/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/operators/mlu/mlu_baseop.h"
+
+namespace paddle {
+namespace operators {
+
+using Tensor = framework::Tensor;
+using DDim = framework::DDim;
+
+template <typename T>
+class LayerNormMLUKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const override {
+    const auto begin_norm_axis = ctx.Attr<int>("begin_norm_axis");
+    const auto epsilon = ctx.Attr<float>("epsilon");
+    const auto* x = ctx.Input<Tensor>("X");
+    const auto* scale = ctx.Input<Tensor>("Scale");
+    const auto* bias = ctx.Input<Tensor>("Bias");
+    auto* y = ctx.Output<Tensor>("Y");
+    auto* mean = ctx.Output<Tensor>("Mean");
+    auto* variance = ctx.Output<Tensor>("Variance");
+
+    auto place = ctx.GetPlace();
+
+    y->mutable_data<T>(place);
+    mean->mutable_data<T>(place);
+    variance->mutable_data<T>(place);
+
+    const auto& x_dims = x->dims();
+    std::vector<int> scale_bias_axes;
+    std::vector<int> mean_var_axes;
+    for (auto i = 0; i < x_dims.size(); ++i) {
+      if (i >= begin_norm_axis) {
+        scale_bias_axes.push_back(x_dims[i]);
+      } else {
+        mean_var_axes.push_back(x_dims[i]);
+      }
+    }
+
+    MLUCnnlTensorDesc x_desc(*x);
+    MLUCnnlTensorDesc y_desc(*y);
+    MLUCnnlTensorDesc mean_var_desc(mean_var_axes.size(), mean_var_axes.data(),
+                                    ToCnnlDataType<T>());
+    // cnnl only supports scale and bias being both NULL or both non-NULL.
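+    // If only one of them is missing, the else-branch below materializes it:
+    // scale defaults to all ones and bias to all zeros, matching
+    // y = scale * (x - mean) / sqrt(var + eps) + bias.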
+    if (!scale && !bias) {
+      MLUCnnl::LayerNormForward(
+          ctx, begin_norm_axis, x_desc.get(), GetBasePtr(x),
+          nullptr /*scale_bias_desc*/, nullptr /*scale*/, nullptr /*bias*/,
+          epsilon, y_desc.get(), GetBasePtr(y), mean_var_desc.get(),
+          GetBasePtr(mean), GetBasePtr(variance));
+    } else {
+      Tensor tmp_scale(x->dtype());
+      if (!scale) {
+        tmp_scale.mutable_data<T>(phi::make_ddim(scale_bias_axes), place);
+        FillMLUTensorWithHostValue(ctx, static_cast<T>(1), &tmp_scale);
+      } else {
+        tmp_scale = *scale;
+      }
+
+      Tensor tmp_bias(x->dtype());
+      if (!bias) {
+        tmp_bias.mutable_data<T>(phi::make_ddim(scale_bias_axes), place);
+        FillMLUTensorWithHostValue(ctx, static_cast<T>(0), &tmp_bias);
+      } else {
+        tmp_bias = *bias;
+      }
+
+      // scale and bias should have the same type as x/y
+      MLUCnnlTensorDesc float32_desc(scale_bias_axes.size(),
+                                     scale_bias_axes.data(), CNNL_DTYPE_FLOAT);
+      MLUCnnlTensorDesc float16_desc(scale_bias_axes.size(),
+                                     scale_bias_axes.data(), CNNL_DTYPE_HALF);
+      cnnlCastDataType_t cast_type = GetCastDataType(VT::FP32, VT::FP16);
+
+      Tensor final_scale(x->dtype());
+      if (final_scale.dtype() == DataType::FLOAT16 &&
+          tmp_scale.dtype() == DataType::FLOAT32) {
+        final_scale.mutable_data<T>(phi::make_ddim(scale_bias_axes), place);
+        // cast scale to fp16
+        MLUCnnl::Cast(ctx, cast_type, float32_desc.get(),
+                      GetBasePtr(&tmp_scale), float16_desc.get(),
+                      GetBasePtr(&final_scale));
+      } else {
+        final_scale = tmp_scale;
+      }
+
+      Tensor final_bias(x->dtype());
+      if (final_bias.dtype() == DataType::FLOAT16 &&
+          tmp_bias.dtype() == DataType::FLOAT32) {
+        final_bias.mutable_data<T>(phi::make_ddim(scale_bias_axes), place);
+        // cast bias to fp16
+        MLUCnnl::Cast(ctx, cast_type, float32_desc.get(), GetBasePtr(&tmp_bias),
+                      float16_desc.get(), GetBasePtr(&final_bias));
+      } else {
+        final_bias = tmp_bias;
+      }
+
+      MLUCnnlTensorDesc scale_bias_desc(
+          scale_bias_axes.size(), scale_bias_axes.data(), ToCnnlDataType<T>());
+      MLUCnnl::LayerNormForward(
+          ctx, begin_norm_axis, x_desc.get(), GetBasePtr(x),
+          scale_bias_desc.get(), GetBasePtr(&final_scale),
+          GetBasePtr(&final_bias), epsilon, y_desc.get(), GetBasePtr(y),
+          mean_var_desc.get(), GetBasePtr(mean), GetBasePtr(variance));
+    }
+  }
+};
+
+template <typename T>
+class LayerNormGradMLUKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const override {
+    const auto begin_norm_axis = ctx.Attr<int>("begin_norm_axis");
+    const auto* x = ctx.Input<Tensor>("X");
+    const auto* mean = ctx.Input<Tensor>("Mean");
+    const auto* variance = ctx.Input<Tensor>("Variance");
+    const auto* scale = ctx.Input<Tensor>("Scale");
+    const auto* dy = ctx.Input<Tensor>(framework::GradVarName("Y"));
+    auto* dx = ctx.Output<Tensor>(framework::GradVarName("X"));
+    auto* dscale = ctx.Output<Tensor>(framework::GradVarName("Scale"));
+    auto* dbias = ctx.Output<Tensor>(framework::GradVarName("Bias"));
+
+    auto place = ctx.GetPlace();
+    dx->mutable_data<T>(place);
+
+    const auto& x_dims = x->dims();
+    std::vector<int> scale_bias_axes;
+    std::vector<int> mean_var_axes;
+    for (auto i = 0; i < x_dims.size(); ++i) {
+      if (i >= begin_norm_axis) {
+        scale_bias_axes.push_back(x_dims[i]);
+      } else {
+        mean_var_axes.push_back(x_dims[i]);
+      }
+    }
+
+    MLUCnnlTensorDesc x_desc(*x);
+    MLUCnnlTensorDesc dy_desc(*dy);
+    MLUCnnlTensorDesc mean_var_desc(mean_var_axes.size(), mean_var_axes.data(),
+                                    ToCnnlDataType<T>());
+    MLUCnnlTensorDesc dx_desc(*dx);
+
+    Tensor tmp_scale(x->dtype());
+    if (!scale) {
+      tmp_scale.mutable_data<T>(phi::make_ddim(scale_bias_axes), place);
+      FillMLUTensorWithHostValue(ctx, static_cast<T>(1), &tmp_scale);
+    } else {
+      tmp_scale = *scale;
+    }
+
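+    // Under AMP the Scale input can stay fp32 while T is fp16; set up fp32 and
+    // fp16 descriptors plus both cast directions so everything handed to CNNL
+    // shares one dtype, mirroring the forward kernel.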
+    MLUCnnlTensorDesc float32_desc(scale_bias_axes.size(),
+                                   scale_bias_axes.data(), CNNL_DTYPE_FLOAT);
+    MLUCnnlTensorDesc float16_desc(scale_bias_axes.size(),
+                                   scale_bias_axes.data(), CNNL_DTYPE_HALF);
+    cnnlCastDataType_t cast_fp32_to_fp16 = GetCastDataType(VT::FP32, VT::FP16);
+    cnnlCastDataType_t cast_fp16_to_fp32 = GetCastDataType(VT::FP16, VT::FP32);
+
+    Tensor final_scale(x->dtype());
+    if (final_scale.dtype() == DataType::FLOAT16 &&
+        tmp_scale.dtype() == DataType::FLOAT32) {
+      final_scale.mutable_data<T>(phi::make_ddim(scale_bias_axes), place);
+      // cast scale to fp16
+      MLUCnnl::Cast(ctx, cast_fp32_to_fp16, float32_desc.get(),
+                    GetBasePtr(&tmp_scale), float16_desc.get(),
+                    GetBasePtr(&final_scale));
+    } else {
+      final_scale = tmp_scale;
+    }
+
+    Tensor tmp_dscale(x->dtype());
+    if (dscale && (tmp_dscale.dtype() == dscale->dtype())) {
+      dscale->mutable_data<T>(place);
+      tmp_dscale = *dscale;
+    } else {
+      tmp_dscale.mutable_data<T>(phi::make_ddim(scale_bias_axes), place);
+    }
+    Tensor tmp_dbias(x->dtype());
+    if (dbias && (tmp_dbias.dtype() == dbias->dtype())) {
+      dbias->mutable_data<T>(place);
+      tmp_dbias = *dbias;
+    } else {
+      tmp_dbias.mutable_data<T>(phi::make_ddim(scale_bias_axes), place);
+    }
+
+    MLUCnnlTensorDesc scale_desc(scale_bias_axes.size(), scale_bias_axes.data(),
+                                 ToCnnlDataType<T>());
+    MLUCnnl::LayerNormBackward(
+        ctx, begin_norm_axis, x_desc.get(), GetBasePtr(x), dy_desc.get(),
+        GetBasePtr(dy), scale_desc.get(), GetBasePtr(&final_scale),
+        mean_var_desc.get(), GetBasePtr(mean), GetBasePtr(variance),
+        dx_desc.get(), GetBasePtr(dx), GetBasePtr(&tmp_dscale),
+        GetBasePtr(&tmp_dbias));
+
+    if (dscale && (tmp_dscale.dtype() == DataType::FLOAT16 &&
+                   dscale->dtype() == DataType::FLOAT32)) {
+      dscale->mutable_data<float>(place);
+      MLUCnnl::Cast(ctx, cast_fp16_to_fp32, float16_desc.get(),
+                    GetBasePtr(&tmp_dscale), float32_desc.get(),
+                    GetBasePtr(dscale));
+    }
+    if (dbias && (tmp_dbias.dtype() == DataType::FLOAT16 &&
+                  dbias->dtype() == DataType::FLOAT32)) {
+      dbias->mutable_data<float>(place);
+      MLUCnnl::Cast(ctx, cast_fp16_to_fp32, float16_desc.get(),
+                    GetBasePtr(&tmp_dbias), float32_desc.get(),
+                    GetBasePtr(dbias));
+    }
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
+
+namespace ops = paddle::operators;
+namespace plat = paddle::platform;
+
+REGISTER_OP_MLU_KERNEL(layer_norm, ops::LayerNormMLUKernel<float>,
+                       ops::LayerNormMLUKernel<plat::float16>);
+REGISTER_OP_MLU_KERNEL(layer_norm_grad, ops::LayerNormGradMLUKernel<float>,
+                       ops::LayerNormGradMLUKernel<plat::float16>);
diff --git a/paddle/fluid/operators/mlu/mlu_baseop.cc b/paddle/fluid/operators/mlu/mlu_baseop.cc
index eacab468005806f108b5edbd34b977cedae7097f..56c9dd855734d1a5bb172a9ac28274fb80392ddb 100644
--- a/paddle/fluid/operators/mlu/mlu_baseop.cc
+++ b/paddle/fluid/operators/mlu/mlu_baseop.cc
@@ -2077,6 +2077,45 @@ MLUCnnlTrigonDesc::~MLUCnnlTrigonDesc() {
   }
 }
 
+/* static */ void MLUCnnl::LayerNormForward(
+    const ExecutionContext& ctx, int axis, const cnnlTensorDescriptor_t x_desc,
+    const void* x, const cnnlTensorDescriptor_t weight_bias_desc,
+    const void* weight, const void* bias, float eps,
+    const cnnlTensorDescriptor_t y_desc, void* y,
+    const cnnlTensorDescriptor_t mean_rstd_desc, void* saved_mean,
+    void* saved_rstd) {
+  cnnlHandle_t handle = GetHandleFromCTX(ctx);
+
+  size_t workspace_size;
+  PADDLE_ENFORCE_MLU_SUCCESS(
+      cnnlGetLayerNormOpWorkspaceSize(handle, axis, x_desc, &workspace_size));
+
+  auto& dev_ctx = GetDevCtxFromCTX(ctx);
+  Tensor workspace = ctx.AllocateTmpTensor<int8_t, MLUDeviceContext>(
+      {static_cast<int64_t>(workspace_size)}, dev_ctx);
+  void* workspace_ptr = workspace.mutable_data(ctx.GetPlace());
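+
+  // cnnlLayerNormForward also writes out saved_mean and saved_rstd (the
+  // reciprocal standard deviation), which cnnlLayerNormBackward consumes.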
+
+  PADDLE_ENFORCE_MLU_SUCCESS(
+      cnnlLayerNormForward(handle, x_desc, x, axis, weight_bias_desc, weight,
+                           bias, eps, workspace_ptr, workspace_size, y_desc, y,
+                           mean_rstd_desc, saved_mean, saved_rstd));
+}
+
+/* static */ void MLUCnnl::LayerNormBackward(
+    const ExecutionContext& ctx, int axis, const cnnlTensorDescriptor_t x_desc,
+    const void* x, const cnnlTensorDescriptor_t diff_z_desc, const void* diff_z,
+    const cnnlTensorDescriptor_t weight_bias_desc, const void* weight,
+    const cnnlTensorDescriptor_t mean_rstd_desc, const void* saved_mean,
+    const void* saved_rstd, const cnnlTensorDescriptor_t diff_x_desc,
+    void* diff_x, void* diff_weight, void* diff_bias) {
+  cnnlHandle_t handle = GetHandleFromCTX(ctx);
+
+  PADDLE_ENFORCE_MLU_SUCCESS(cnnlLayerNormBackward(
+      handle, x_desc, x, axis, diff_z_desc, diff_z, weight_bias_desc, weight,
+      mean_rstd_desc, saved_mean, saved_rstd, diff_x_desc, diff_x, diff_weight,
+      diff_bias));
+}
+
 /* static */ void MLUCnnl::QuantizeParam(
     const ExecutionContext& ctx, const cnnlQuantizeMode_t mode,
     const int bitwidth, const cnnlTensorDescriptor_t input_desc,
diff --git a/paddle/fluid/operators/mlu/mlu_baseop.h b/paddle/fluid/operators/mlu/mlu_baseop.h
index 572b7aa2bbd01e1fb02ad2a9651df4af710d3b24..71ea27d690f11d57c83ca2a63e77aa81a9bc4545 100644
--- a/paddle/fluid/operators/mlu/mlu_baseop.h
+++ b/paddle/fluid/operators/mlu/mlu_baseop.h
@@ -146,10 +146,8 @@ const std::map<std::pair<VT::Type, VT::Type>, cnnlCastDataType_t>
     {{VT::FP16, /*cast to*/ VT::BOOL}, CNNL_CAST_HALF_TO_BOOL},
     {{VT::INT32, /*cast to*/ VT::FP32}, CNNL_CAST_INT32_TO_FLOAT},
     {{VT::INT32, /*cast to*/ VT::FP16}, CNNL_CAST_INT32_TO_HALF},
-    {{VT::INT32, /*cast to*/ VT::INT64}, CNNL_CAST_INT32_TO_INT64},
-    {{VT::INT32, /*cast to*/ VT::INT16}, CNNL_CAST_INT32_TO_INT16},
     {{VT::INT32, /*cast to*/ VT::INT8}, CNNL_CAST_INT32_TO_INT8},
-    {{VT::INT32, /*cast to*/ VT::BOOL}, CNNL_CAST_INT32_TO_BOOL},
+    {{VT::INT32, /*cast to*/ VT::INT16}, CNNL_CAST_INT32_TO_INT16},
     {{VT::INT16, /*cast to*/ VT::FP32}, CNNL_CAST_INT16_TO_FLOAT},
     {{VT::INT16, /*cast to*/ VT::FP16}, CNNL_CAST_INT16_TO_HALF},
     {{VT::INT16, /*cast to*/ VT::INT32}, CNNL_CAST_INT16_TO_INT32},
@@ -158,12 +156,21 @@ const std::map<std::pair<VT::Type, VT::Type>, cnnlCastDataType_t>
     {{VT::INT8, /*cast to*/ VT::INT32}, CNNL_CAST_INT8_TO_INT32},
     {{VT::UINT8, /*cast to*/ VT::FP32}, CNNL_CAST_UINT8_TO_FLOAT},
     {{VT::UINT8, /*cast to*/ VT::FP16}, CNNL_CAST_UINT8_TO_HALF},
-    {{VT::UINT8, /*cast to*/ VT::INT64}, CNNL_CAST_UINT8_TO_INT64},
-    {{VT::UINT8, /*cast to*/ VT::INT32}, CNNL_CAST_UINT8_TO_INT32},
     {{VT::BOOL, /*cast to*/ VT::FP32}, CNNL_CAST_BOOL_TO_FLOAT},
     {{VT::BOOL, /*cast to*/ VT::FP16}, CNNL_CAST_BOOL_TO_HALF},
     {{VT::BOOL, /*cast to*/ VT::INT32}, CNNL_CAST_BOOL_TO_INT32},
+    {{VT::UINT8, /*cast to*/ VT::INT32}, CNNL_CAST_UINT8_TO_INT32},
+    {{VT::INT32, /*cast to*/ VT::INT64}, CNNL_CAST_INT32_TO_INT64},
     {{VT::INT64, /*cast to*/ VT::INT32}, CNNL_CAST_INT64_TO_INT32},
+    {{VT::INT32, /*cast to*/ VT::BOOL}, CNNL_CAST_INT32_TO_BOOL},
+    {{VT::UINT8, /*cast to*/ VT::INT64}, CNNL_CAST_UINT8_TO_INT64},
+    {{VT::INT8, /*cast to*/ VT::INT16}, CNNL_CAST_INT8_TO_INT16},
+    {{VT::FP32, /*cast to*/ VT::FP64}, CNNL_CAST_FLOAT_TO_DOUBLE},
+    {{VT::FP64, /*cast to*/ VT::FP32}, CNNL_CAST_DOUBLE_TO_FLOAT},
+    {{VT::INT64, /*cast to*/ VT::FP32}, CNNL_CAST_INT64_TO_FLOAT},
+    {{VT::INT64, /*cast to*/ VT::FP16}, CNNL_CAST_INT64_TO_HALF},
+    {{VT::FP32, /*cast to*/ VT::INT64}, CNNL_CAST_FLOAT_TO_INT64},
+    {{VT::FP16, /*cast to*/ VT::INT64}, CNNL_CAST_HALF_TO_INT64},
 };
 
 cnnlCastDataType_t GetCastDataType(const VT::Type& src_type,
@@ -1103,6 +1110,24 @@ class MLUCnnl {
                                 const cnnlTensorDescriptor_t x_backprop_desc,
                                 void* x_backprop, void* scale_backprop,
                                 void* offset_backprop);
+  static void LayerNormForward(const ExecutionContext& ctx, int axis,
+                               const cnnlTensorDescriptor_t x_desc,
+                               const void* x,
+                               const cnnlTensorDescriptor_t weight_bias_desc,
+                               const void* weight, const void* bias, float eps,
+                               const cnnlTensorDescriptor_t y_desc, void* y,
+                               const cnnlTensorDescriptor_t mean_rstd_desc,
+                               void* saved_mean, void* saved_rstd);
+
+  static void LayerNormBackward(
+      const ExecutionContext& ctx, int axis,
+      const cnnlTensorDescriptor_t x_desc, const void* x,
+      const cnnlTensorDescriptor_t diff_z_desc, const void* diff_z,
+      const cnnlTensorDescriptor_t weight_bias_desc, const void* weight,
+      const cnnlTensorDescriptor_t mean_rstd_desc, const void* saved_mean,
+      const void* saved_rstd, const cnnlTensorDescriptor_t diff_x_desc,
+      void* diff_x, void* diff_weight, void* diff_bias);
+
   static void Transpose(const ExecutionContext& ctx,
                         const std::vector<int> perm, const int input_dim,
                         const cnnlTensorDescriptor_t input_desc,
@@ -1230,5 +1255,13 @@ inline void TransposeFromMLUTensor(const ExecutionContext& ctx,
                          GetBasePtr(transformed_output));
 }
 
+template <typename T>
+inline void FillMLUTensorWithHostValue(const ExecutionContext& ctx, T value,
+                                       Tensor* out) {
+  MLUCnnlTensorDesc out_desc(*out);
+  MLUCnnl::Fill(ctx, CNNL_POINTER_MODE_HOST, &value, out_desc.get(),
+                GetBasePtr(out));
+}
+
 }  // namespace operators
 }  // namespace paddle
diff --git a/python/paddle/fluid/tests/unittests/mlu/test_layer_norm_op_mlu.py b/python/paddle/fluid/tests/unittests/mlu/test_layer_norm_op_mlu.py
new file mode 100644
index 0000000000000000000000000000000000000000..8b32692020cbffe402dfe1d340fa5a40a411c9ce
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/mlu/test_layer_norm_op_mlu.py
@@ -0,0 +1,309 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
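+
+# These tests mirror test_layer_norm_op.py but run on MLUPlace, reusing its
+# numpy reference implementations of the forward and backward passes. Note
+# that the MLU kernel stores rstd rather than variance, so the variance
+# output is checked against 1 / np.sqrt(variance) below.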
+
+from __future__ import print_function
+import unittest
+import numpy as np
+import paddle
+
+from operator import mul
+import paddle.fluid.core as core
+import paddle.fluid as fluid
+import paddle.nn.functional as F
+from functools import reduce
+import sys
+sys.path.append('..')
+from op_test import _set_use_system_allocator
+from paddle.fluid import Program, program_guard
+from paddle.fluid.contrib.mixed_precision.fp16_utils import _keep_layer_norm_scale_bias_to_fp32
+from test_layer_norm_op import _reference_layer_norm_naive, _reference_layer_norm_grad
+
+paddle.enable_static()
+
+np.random.seed(123)
+
+_set_use_system_allocator(True)
+
+
+class TestLayerNormOp(unittest.TestCase):
+    def setUp(self):
+        self.use_cudnn = True
+        self.place = paddle.device.MLUPlace(0)
+        self.__class__.use_mlu = True
+
+    def __assert_close(self, tensor, np_array, msg, atol=1e-4):
+        self.assertTrue(np.allclose(np.array(tensor), np_array, atol=atol), msg)
+
+    def check_forward_backward(self,
+                               shape,
+                               begin_norm_axis,
+                               has_scale=True,
+                               has_bias=True,
+                               y_grad_scale=1.0,
+                               use_mkldnn=False):
+        def test_with_place(place,
+                            shape,
+                            begin_norm_axis,
+                            use_mkldnn=use_mkldnn):
+            # attr
+            epsilon = 0.00001
+            x_shape = shape
+            D = reduce(mul, x_shape[begin_norm_axis:len(x_shape)], 1)
+            scale_shape = [D]
+
+            np.random.seed(123)
+            x = np.random.random_sample(x_shape).astype(np.float32)
+            scale = np.random.random_sample(scale_shape).astype(
+                np.float32) if has_scale else None
+            bias = np.random.random_sample(scale_shape).astype(
+                np.float32) if has_bias else None
+            y_grad = (np.random.random_sample(x_shape) *
+                      y_grad_scale).astype(np.float32)
+
+            # reference forward & backward
+            y, mean, variance = _reference_layer_norm_naive(
+                x, scale, bias, epsilon, begin_norm_axis)
+            x_grad, scale_grad, bias_grad = _reference_layer_norm_grad(
+                x, y_grad, scale, bias, mean, variance, begin_norm_axis)
+
+            var_dict = locals()
+            var_dict['y@GRAD'] = y_grad
+            var_names = ['x', 'mean', 'variance', 'y', 'y@GRAD']
+            if has_scale:
+                var_names += ['scale']
+            if has_bias:
+                var_names += ['bias']
+            ground_truth = {name: var_dict[name] for name in var_names}
+
+            program = fluid.Program()
+            with fluid.program_guard(program):
+                block = program.global_block()
+                for name in ground_truth:
+                    block.create_var(
+                        name=name,
+                        dtype='float32',
+                        shape=ground_truth[name].shape)
+                inputs = {"X": block.var('x')}
+                fetch_list = [
+                    'y',
+                    'mean',
+                    'variance',
+                    'x@GRAD',
+                ]
+                if has_scale:
+                    inputs["Scale"] = block.var('scale')
+                    fetch_list += ['scale@GRAD']
+                if has_bias:
+                    inputs["Bias"] = block.var('bias')
+                    fetch_list += ['bias@GRAD']
+                layer_norm_op = block.append_op(
+                    type="layer_norm",
+                    inputs=inputs,
+                    outputs={
+                        "Y": block.var('y'),
+                        "Mean": block.var('mean'),  # share the same memory
+                        "Variance":
+                        block.var('variance'),  # share the same memory
+                    },
+                    attrs={
+                        "epsilon": epsilon,
+                        "begin_norm_axis": begin_norm_axis,
+                        "use_mkldnn": use_mkldnn
+                    })
+                # generate backward op_desc
+                grad_op_desc_list, op_grad_to_var = core.get_grad_op_desc(
+                    layer_norm_op.desc, set(), [])
+                grad_op_desc = grad_op_desc_list[0]
+                new_op_desc = block.desc.append_op()
+                new_op_desc.copy_from(grad_op_desc)
+                for var_name in grad_op_desc.output_arg_names():
+                    block.desc.var(var_name.encode("ascii"))
+                grad_op_desc.infer_var_type(block.desc)
+                grad_op_desc.infer_shape(block.desc)
+                for arg in grad_op_desc.output_arg_names():
+                    grad_var = block.desc.find_var(arg.encode("ascii"))
+                    grad_var.set_dtype(core.VarDesc.VarType.FP32)
+
+                program._sync_with_cpp()
+
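+                # Run the program once on the MLU place; this executes the
+                # forward layer_norm op and the appended layer_norm_grad op.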
+                exe = fluid.Executor(place)
+                out = exe.run(program,
+                              feed={
+                                  name: var_dict[name]
+                                  for name in ground_truth
+                              },
+                              fetch_list=fetch_list)
+
+                self.__assert_close(y, out[0], "y")
+                self.__assert_close(mean, out[1], "mean")
+                self.__assert_close(1 / np.sqrt(variance), out[2], "variance",
+                                    1e-3)
+                self.__assert_close(x_grad, out[3], "x_grad")
+                if has_scale:
+                    self.__assert_close(scale_grad,
+                                        out[fetch_list.index('scale@GRAD')],
+                                        "scale_grad", 1e-3)
+                if has_bias:
+                    self.__assert_close(bias_grad,
+                                        out[fetch_list.index('bias@GRAD')],
+                                        "bias_grad")
+
+        test_with_place(self.place, shape, begin_norm_axis)
+
+    def test_check_forward_backward_with_scale_and_bias(self):
+        self.check_forward_backward(shape=[1, 3, 4, 5], begin_norm_axis=1)
+        self.check_forward_backward(shape=[2, 3, 4, 5], begin_norm_axis=1)
+        self.check_forward_backward(
+            shape=[2, 3, 4, 5],
+            begin_norm_axis=1,
+            has_scale=False,
+            has_bias=True)
+        self.check_forward_backward(
+            shape=[2, 3, 4, 5],
+            begin_norm_axis=1,
+            has_scale=True,
+            has_bias=False)
+        self.check_forward_backward(
+            shape=[2, 3, 4, 5],
+            begin_norm_axis=1,
+            has_scale=False,
+            has_bias=False)
+        self.check_forward_backward(shape=[2, 3, 4, 5], begin_norm_axis=3)
+        self.check_forward_backward(
+            shape=[92, 513, 129], begin_norm_axis=2, y_grad_scale=0.1)
+        self.check_forward_backward(shape=[3, 34, 1134], begin_norm_axis=2)
+        self.check_forward_backward(
+            shape=[92, 513, 1134], begin_norm_axis=2, y_grad_scale=0.1)
+        self.check_forward_backward(
+            shape=[92, 513, 1134],
+            begin_norm_axis=2,
+            has_scale=False,
+            has_bias=True,
+            y_grad_scale=0.1)
+        self.check_forward_backward(
+            shape=[92, 513, 1134],
+            begin_norm_axis=2,
+            has_scale=True,
+            has_bias=False,
+            y_grad_scale=0.1)
+        self.check_forward_backward(
+            shape=[92, 513, 1134],
+            begin_norm_axis=2,
+            has_scale=False,
+            has_bias=False,
+            y_grad_scale=0.1)
+        self.check_forward_backward(
+            shape=[512, 1024], begin_norm_axis=1, has_scale=True, has_bias=True)
+
+
+class TestLayerNormAPI(unittest.TestCase):
+    def test_case(self):
+        x = fluid.layers.data(
+            name='x',
+            shape=[64, 32, 256],
+            dtype='float32',
+            append_batch_size=False)
+        x = fluid.layers.layer_norm(
+            x,
+            scale=True,
+            shift=True,
+            begin_norm_axis=1,
+            epsilon=1e-05,
+            param_attr=None,
+            bias_attr=None)
+        x = fluid.layers.layer_norm(
+            x,
+            scale=False,
+            shift=False,
+            begin_norm_axis=1,
+            epsilon=1e-05,
+            param_attr=None,
+            bias_attr=None)
+        x = fluid.layers.layer_norm(
+            x,
+            scale=False,
+            shift=False,
+            begin_norm_axis=1,
+            epsilon=1e-05,
+            param_attr="scale",
+            bias_attr="shift")
+
+
+class TestDygraphLayerNormAPIError(unittest.TestCase):
+    def test_errors(self):
+        with program_guard(Program(), Program()):
+            paddle.enable_static()
+
+            layer_norm = fluid.LayerNorm([32, 32])
+            # the input of LayerNorm must be Variable.
+            x1 = np.random.random((3, 32, 32)).astype('float32')
+            self.assertRaises(TypeError, layer_norm, x1)
+
+            # the input dtype of LayerNorm must be float32 or float16
+            x2 = fluid.layers.data(name='x2', shape=[3, 32, 32], dtype="int32")
+            self.assertRaises(TypeError, layer_norm, x2)
+
+
+class TestFP16ScaleBiasLayerNorm(unittest.TestCase):
+    def check_main(self, x_np, weight_np, bias_np, dtype):
+        paddle.disable_static()
+
+        weight_np = weight_np.astype(dtype)
+        bias_np = bias_np.astype(dtype)
+
+        x = paddle.to_tensor(x_np)
+        weight = paddle.to_tensor(weight_np)
+        bias = paddle.to_tensor(bias_np)
+        x.stop_gradient = False
+        weight.stop_gradient = False
+        bias.stop_gradient = False
+        y = F.layer_norm(x, x.shape[1:], weight, bias)
+        x_g, w_g, b_g = paddle.grad(y, [x, weight, bias])
+        y_np = y.numpy().astype('float32')
+        x_g_np = x_g.numpy().astype('float32')
+        w_g_np = w_g.numpy().astype('float16')
+        b_g_np = b_g.numpy().astype('float32')
+
+        paddle.enable_static()
+        return y_np, x_g_np, w_g_np, b_g_np
+
+    def test_main(self):
+        x_np = np.random.random([10, 20]).astype('float16')
+        weight_np = np.random.random([20]).astype('float16')
+        bias_np = np.random.random([20]).astype('float16')
+
+        y_np_1, x_g_np_1, w_g_np_1, b_g_np_1 = self.check_main(
+            x_np, weight_np, bias_np, 'float16')
+        y_np_2, x_g_np_2, w_g_np_2, b_g_np_2 = self.check_main(
+            x_np, weight_np, bias_np, 'float32')
+
+        def assert_equal(x, y):
+            self.assertTrue(np.array_equal(x, y))
+
+        assert_equal(y_np_1, y_np_2)
+        assert_equal(x_g_np_1, x_g_np_2)
+        assert_equal(w_g_np_1, w_g_np_2)
+        assert_equal(b_g_np_1, b_g_np_2)
+
+
+class TestGetSetKeepLayerNormScaleBiasFP32Flag(unittest.TestCase):
+    def test_main(self):
+        self.assertTrue(_keep_layer_norm_scale_bias_to_fp32())
+        _keep_layer_norm_scale_bias_to_fp32(False)
+        self.assertFalse(_keep_layer_norm_scale_bias_to_fp32())
+        _keep_layer_norm_scale_bias_to_fp32(True)
+        self.assertTrue(_keep_layer_norm_scale_bias_to_fp32())
+
+
+if __name__ == '__main__':
+    unittest.main()