diff --git a/paddle/fluid/operators/layer_norm_op_npu.cc b/paddle/fluid/operators/layer_norm_op_npu.cc new file mode 100644 index 0000000000000000000000000000000000000000..447eda1a8a4c43c9bb6c7dfb325ac1436b100517 --- /dev/null +++ b/paddle/fluid/operators/layer_norm_op_npu.cc @@ -0,0 +1,195 @@ +/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/operators/layer_norm_op.h" +#include "paddle/fluid/operators/npu_op_runner.h" + +namespace paddle { +namespace operators { + +using Tensor = framework::Tensor; +using DDim = framework::DDim; + +template +class LayerNormNPUKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + const auto begin_norm_axis = ctx.Attr("begin_norm_axis"); + const auto epsilon = ctx.Attr("epsilon"); + const auto* x = ctx.Input("X"); + const auto* scale = ctx.Input("Scale"); + const auto* bias = ctx.Input("Bias"); + auto* y = ctx.Output("Y"); + auto* mean = ctx.Output("Mean"); + auto* variance = ctx.Output("Variance"); + const auto& x_dims = x->dims(); + std::vector axes; + auto matrix_dim = framework::flatten_to_2d(x_dims, begin_norm_axis); + int right = static_cast(matrix_dim[1]); + + // The shape of scale and bias should be equal to x.shape[begin_norm_axis:], + // required by Ascend. + for (auto i = begin_norm_axis; i < x_dims.size(); ++i) { + axes.push_back(x_dims[i]); + } + auto place = ctx.GetPlace(); + auto stream = + ctx.template device_context() + .stream(); + + Tensor default_scale(x->type()); + if (!scale) { + default_scale.mutable_data(framework::make_ddim(axes), place); + Tensor value(x->type()); + value.mutable_data({1}, place); + TensorFromVector(std::vector{static_cast(1.0)}, + ctx.device_context(), &value); + auto runner = + NpuOpRunner("FillD", {value}, {default_scale}, {{"dims", axes}}); + runner.Run(stream); + scale = &default_scale; + } else { + const_cast(scale)->Resize(framework::make_ddim(axes)); + } + + Tensor default_bias(x->type()); + if (!bias) { + default_bias.mutable_data(framework::make_ddim(axes), place); + Tensor value(x->type()); + value.mutable_data({1}, place); + TensorFromVector(std::vector{static_cast(0)}, ctx.device_context(), + &value); + auto runner = + NpuOpRunner("FillD", {value}, {default_bias}, {{"dims", axes}}); + runner.Run(stream); + bias = &default_bias; + } else { + const_cast(bias)->Resize(framework::make_ddim(axes)); + } + y->mutable_data(ctx.GetPlace()); + mean->mutable_data(ctx.GetPlace()); + variance->mutable_data(ctx.GetPlace()); + + auto runner = + NpuOpRunner("LayerNorm", {*x, *scale, *bias}, {*y, *mean, *variance}, + {{"begin_norm_axis", begin_norm_axis}, + {"begin_params_axis", begin_norm_axis}, + {"epsilon", epsilon}}); + runner.Run(stream); + // revert shape of scale and bias + // TODO(zhiqiu): better implementation, use tmp tensor to avoid write input + // tensor. + const_cast(scale)->Resize(framework::make_ddim({right})); + const_cast(bias)->Resize(framework::make_ddim({right})); + } +}; + +template +class LayerNormGradNPUKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + const auto begin_norm_axis = ctx.Attr("begin_norm_axis"); + const auto* x = ctx.Input("X"); + const auto& x_dims = x->dims(); + const auto* mean = ctx.Input("Mean"); + const auto* variance = ctx.Input("Variance"); + const auto* scale = ctx.Input("Scale"); + const auto* dy = ctx.Input(framework::GradVarName("Y")); + auto* dx = ctx.Output(framework::GradVarName("X")); + auto* dscale = ctx.Output(framework::GradVarName("Scale")); + auto* dbias = ctx.Output(framework::GradVarName("Bias")); + + auto matrix_dim = framework::flatten_to_2d(x_dims, begin_norm_axis); + int right = static_cast(matrix_dim[1]); + + std::vector axes; + for (auto i = begin_norm_axis; i < x_dims.size(); ++i) { + axes.push_back(x_dims[i]); + } + + auto place = ctx.GetPlace(); + auto stream = + ctx.template device_context() + .stream(); + + // No need to compute any gradient, jusr return + if (!dx && !dscale && !dbias) { + return; + } + + // The rank of mean should be equal to x, required by Ascend. + std::vector new_shape; + for (auto i = 0; i < begin_norm_axis; ++i) { + new_shape.push_back(x_dims[i]); + } + for (auto i = begin_norm_axis; i < x_dims.size(); ++i) { + new_shape.push_back(1); + } + + auto mean_dims = mean->dims(); + const_cast(mean)->Resize(framework::make_ddim({new_shape})); + const_cast(variance)->Resize(framework::make_ddim({new_shape})); + + Tensor default_scale(x->type()); + if (!scale) { + default_scale.mutable_data(framework::make_ddim(axes), place); + Tensor value(x->type()); + value.mutable_data({1}, place); + TensorFromVector(std::vector{static_cast(1.0)}, + ctx.device_context(), &value); + auto runner = + NpuOpRunner("FillD", {value}, {default_scale}, {{"dims", axes}}); + runner.Run(stream); + scale = &default_scale; + } else { + const_cast(scale)->Resize(framework::make_ddim(axes)); + } + + Tensor dx_(dy->type()), dscale_(dy->type()), dbias_(dy->type()); + dx = (dx == nullptr) ? &dx_ : dx; + dscale = (dscale == nullptr) ? &dscale_ : dscale; + dbias = (dbias == nullptr) ? &dbias_ : dbias; + + dscale->Resize(framework::make_ddim(axes)); + dscale->mutable_data(ctx.GetPlace()); + + dbias->Resize(framework::make_ddim(axes)); + dbias->mutable_data(ctx.GetPlace()); + + dx->Resize(x->dims()); + dx->mutable_data(ctx.GetPlace()); + + auto runner = + NpuOpRunner("LayerNormGrad", {*dy, *x, *variance, *mean, *scale}, + {*dx, *dscale, *dbias}, {}); + runner.Run(stream); + + const_cast(mean)->Resize(mean_dims); + const_cast(variance)->Resize(mean_dims); + const_cast(scale)->Resize(framework::make_ddim({right})); + dscale->Resize(framework::make_ddim({right})); + dbias->Resize(framework::make_ddim({right})); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +namespace plat = paddle::platform; + +REGISTER_OP_NPU_KERNEL(layer_norm, ops::LayerNormNPUKernel, + ops::LayerNormNPUKernel); +REGISTER_OP_NPU_KERNEL(layer_norm_grad, ops::LayerNormGradNPUKernel, + ops::LayerNormGradNPUKernel); diff --git a/paddle/fluid/operators/npu_op_runner.cc b/paddle/fluid/operators/npu_op_runner.cc index 97491b22219b2486569823740abdbad00bac7dd2..7ccc6b60c1ef2dc26dab91fcdbf4c0efda3be4a3 100644 --- a/paddle/fluid/operators/npu_op_runner.cc +++ b/paddle/fluid/operators/npu_op_runner.cc @@ -143,8 +143,8 @@ NpuOpRunner &NpuOpRunner::AddAttr(const std::string &name, data.push_back(v.data()); num.push_back(v.size()); } - PADDLE_ENFORCE_NPU_SUCCESS( - aclopSetAttrListListInt(attr_, name.c_str(), data.size(), num.data(), data.data())); + PADDLE_ENFORCE_NPU_SUCCESS(aclopSetAttrListListInt( + attr_, name.c_str(), data.size(), num.data(), data.data())); } else { PADDLE_THROW(platform::errors::Unimplemented( "Can not convert attribubte '%s' to convert to aclopAttr", name)); @@ -234,8 +234,9 @@ aclTensorDesc *NpuOpRunner::CreateTensorDesc(Tensor tensor) { auto format = ConvertToNpuFormat(tensor.layout()); auto dims = framework::vectorize(tensor.dims()); - VLOG(4) << dtype << " " << dims.size() << " " << dims[0] << "," << dims[1] - << " " << format; + VLOG(4) << "dtype:" << dtype << " " + << "rank:" << dims.size() << " dims:" << tensor.dims() + << " format:" << format; auto *desc = aclCreateTensorDesc(dtype, dims.size(), dims.data(), format); PADDLE_ENFORCE_NOT_NULL( diff --git a/python/paddle/fluid/tests/unittests/npu/test_layer_norm_op_npu.py b/python/paddle/fluid/tests/unittests/npu/test_layer_norm_op_npu.py new file mode 100644 index 0000000000000000000000000000000000000000..243f1e25e7877a9411d2ee3ea40470df563d82d4 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/npu/test_layer_norm_op_npu.py @@ -0,0 +1,191 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import numpy as np +import unittest +import sys +sys.path.append("..") +from op_test import OpTest +from functools import reduce +from operator import mul +import paddle +import paddle.fluid as fluid +import paddle.fluid.core as core +from test_layer_norm_op import _reference_layer_norm_naive, _reference_layer_norm_grad + +paddle.enable_static() + +SEED = 2021 +EPOCH = 100 + +from op_test import _set_use_system_allocator + +_set_use_system_allocator(False) + + +@unittest.skipIf(not paddle.is_compiled_with_npu(), + "core is not compiled with NPU") +class TestLayerNormOp(unittest.TestCase): + def setUp(self): + self.use_cudnn = True + self.set_npu() + self.init_dtype() + + def set_npu(self): + self.__class__.use_npu = True + self.place = paddle.NPUPlace(0) + + def init_dtype(self): + self.dtype = np.float32 + + def __assert_close(self, tensor, np_array, msg, atol=1e-4): + self.assertTrue(np.allclose(np.array(tensor), np_array, atol=atol), msg) + + def check_forward_backward(self, + shape, + begin_norm_axis, + has_scale=True, + has_bias=True, + y_grad_scale=1.0, + use_mkldnn=False): + def test_with_place(place, + shape, + begin_norm_axis, + use_mkldnn=use_mkldnn): + # attr + epsilon = 0.00001 + x_shape = shape + D = reduce(mul, x_shape[begin_norm_axis:len(x_shape)], 1) + scale_shape = [D] + + np.random.seed(123) + x = np.random.random_sample(x_shape).astype(np.float32) + scale = np.random.random_sample(scale_shape).astype( + np.float32) if has_scale else None + bias = np.random.random_sample(scale_shape).astype( + np.float32) if has_bias else None + y_grad = (np.random.random_sample(x_shape) * + y_grad_scale).astype(np.float32) + + # reference forward & backward + y, mean, variance = _reference_layer_norm_naive( + x, scale, bias, epsilon, begin_norm_axis) + x_grad, scale_grad, bias_grad = _reference_layer_norm_grad( + x, y_grad, scale, bias, mean, variance, begin_norm_axis) + + var_dict = locals() + var_dict['y@GRAD'] = y_grad + var_names = ['x', 'mean', 'variance', 'y', 'y@GRAD'] + if has_scale: + var_names += ['scale'] + if has_bias: + var_names += ['bias'] + ground_truth = {name: var_dict[name] for name in var_names} + + program = fluid.Program() + with fluid.program_guard(program): + block = program.global_block() + for name in ground_truth: + block.create_var( + name=name, + dtype='float32', + shape=ground_truth[name].shape) + inputs = {"X": block.var('x')} + fetch_list = [ + 'y', + 'mean', + 'variance', + 'x@GRAD', + ] + if has_scale: + inputs["Scale"] = block.var('scale') + fetch_list += ['scale@GRAD'] + if has_bias: + inputs["Bias"] = block.var('bias') + fetch_list += ['bias@GRAD'] + layer_norm_op = block.append_op( + type="layer_norm", + inputs=inputs, + outputs={ + "Y": block.var('y'), + "Mean": block.var('mean'), # share the same memory + "Variance": + block.var('variance'), # share the same memory + }, + attrs={ + "epsilon": epsilon, + "begin_norm_axis": begin_norm_axis, + "use_mkldnn": use_mkldnn + }) + # generate backward op_desc + grad_op_desc_list, op_grad_to_var = core.get_grad_op_desc( + layer_norm_op.desc, set(), []) + grad_op_desc = grad_op_desc_list[0] + new_op_desc = block.desc.append_op() + new_op_desc.copy_from(grad_op_desc) + for var_name in grad_op_desc.output_arg_names(): + block.desc.var(var_name.encode("ascii")) + grad_op_desc.infer_var_type(block.desc) + grad_op_desc.infer_shape(block.desc) + for arg in grad_op_desc.output_arg_names(): + grad_var = block.desc.find_var(arg.encode("ascii")) + grad_var.set_dtype(core.VarDesc.VarType.FP32) + + program._sync_with_cpp() + exe = fluid.Executor(place) + out = exe.run(program, + feed={ + name: var_dict[name] + for name in ['x', 'scale', 'bias', 'y@GRAD'] + }, + fetch_list=fetch_list) + self.__assert_close(y, out[0], "y") + self.__assert_close(mean, out[1], "mean") + self.__assert_close(variance, out[2], "variance", 1e-3) + self.__assert_close(x_grad, out[3], "x_grad", 1e-2) + if has_scale: + self.__assert_close(scale_grad, + out[fetch_list.index('scale@GRAD')], + "scale_grad", 1e-3) + if has_bias: + self.__assert_close(bias_grad, + out[fetch_list.index('bias@GRAD')], + "bias_grad") + + test_with_place(self.place, shape, begin_norm_axis) + + def test_check_forward_backward_with_scale_and_bias(self): + self.check_forward_backward(shape=[2, 3, 4, 5], begin_norm_axis=1) + self.check_forward_backward( + shape=[2, 3, 4, 5], + begin_norm_axis=1, + has_scale=False, + has_bias=True) + self.check_forward_backward( + shape=[2, 3, 4, 5], + begin_norm_axis=1, + has_scale=True, + has_bias=False) + self.check_forward_backward( + shape=[2, 3, 4, 5], + begin_norm_axis=1, + has_scale=False, + has_bias=False) + self.check_forward_backward(shape=[2, 3, 4, 5], begin_norm_axis=3) + + +if __name__ == '__main__': + unittest.main()