未验证 提交 0310945f 编写于 作者: L Leo Chen 提交者: GitHub

[NPU] Support npu op layer_norm and layer_norm_grad (#31310)

* init commit, add layer_norm npu kernel

* fix typo

* add unittest

* add unittest

* fix bug

* fix bug

* refine ut
上级 45765d6e
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/layer_norm_op.h"
#include "paddle/fluid/operators/npu_op_runner.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
using DDim = framework::DDim;
template <typename T>
class LayerNormNPUKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
const auto begin_norm_axis = ctx.Attr<int>("begin_norm_axis");
const auto epsilon = ctx.Attr<float>("epsilon");
const auto* x = ctx.Input<Tensor>("X");
const auto* scale = ctx.Input<Tensor>("Scale");
const auto* bias = ctx.Input<Tensor>("Bias");
auto* y = ctx.Output<Tensor>("Y");
auto* mean = ctx.Output<Tensor>("Mean");
auto* variance = ctx.Output<Tensor>("Variance");
const auto& x_dims = x->dims();
std::vector<int> axes;
auto matrix_dim = framework::flatten_to_2d(x_dims, begin_norm_axis);
int right = static_cast<int>(matrix_dim[1]);
// The shape of scale and bias should be equal to x.shape[begin_norm_axis:],
// required by Ascend.
for (auto i = begin_norm_axis; i < x_dims.size(); ++i) {
axes.push_back(x_dims[i]);
}
auto place = ctx.GetPlace();
auto stream =
ctx.template device_context<paddle::platform::NPUDeviceContext>()
.stream();
Tensor default_scale(x->type());
if (!scale) {
default_scale.mutable_data<T>(framework::make_ddim(axes), place);
Tensor value(x->type());
value.mutable_data<T>({1}, place);
TensorFromVector(std::vector<T>{static_cast<T>(1.0)},
ctx.device_context(), &value);
auto runner =
NpuOpRunner("FillD", {value}, {default_scale}, {{"dims", axes}});
runner.Run(stream);
scale = &default_scale;
} else {
const_cast<Tensor*>(scale)->Resize(framework::make_ddim(axes));
}
Tensor default_bias(x->type());
if (!bias) {
default_bias.mutable_data<T>(framework::make_ddim(axes), place);
Tensor value(x->type());
value.mutable_data<T>({1}, place);
TensorFromVector(std::vector<T>{static_cast<T>(0)}, ctx.device_context(),
&value);
auto runner =
NpuOpRunner("FillD", {value}, {default_bias}, {{"dims", axes}});
runner.Run(stream);
bias = &default_bias;
} else {
const_cast<Tensor*>(bias)->Resize(framework::make_ddim(axes));
}
y->mutable_data<T>(ctx.GetPlace());
mean->mutable_data<T>(ctx.GetPlace());
variance->mutable_data<T>(ctx.GetPlace());
auto runner =
NpuOpRunner("LayerNorm", {*x, *scale, *bias}, {*y, *mean, *variance},
{{"begin_norm_axis", begin_norm_axis},
{"begin_params_axis", begin_norm_axis},
{"epsilon", epsilon}});
runner.Run(stream);
// revert shape of scale and bias
// TODO(zhiqiu): better implementation, use tmp tensor to avoid write input
// tensor.
const_cast<Tensor*>(scale)->Resize(framework::make_ddim({right}));
const_cast<Tensor*>(bias)->Resize(framework::make_ddim({right}));
}
};
template <typename T>
class LayerNormGradNPUKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
const auto begin_norm_axis = ctx.Attr<int>("begin_norm_axis");
const auto* x = ctx.Input<Tensor>("X");
const auto& x_dims = x->dims();
const auto* mean = ctx.Input<Tensor>("Mean");
const auto* variance = ctx.Input<Tensor>("Variance");
const auto* scale = ctx.Input<Tensor>("Scale");
const auto* dy = ctx.Input<Tensor>(framework::GradVarName("Y"));
auto* dx = ctx.Output<Tensor>(framework::GradVarName("X"));
auto* dscale = ctx.Output<Tensor>(framework::GradVarName("Scale"));
auto* dbias = ctx.Output<Tensor>(framework::GradVarName("Bias"));
auto matrix_dim = framework::flatten_to_2d(x_dims, begin_norm_axis);
int right = static_cast<int>(matrix_dim[1]);
std::vector<int> axes;
for (auto i = begin_norm_axis; i < x_dims.size(); ++i) {
axes.push_back(x_dims[i]);
}
auto place = ctx.GetPlace();
auto stream =
ctx.template device_context<paddle::platform::NPUDeviceContext>()
.stream();
// No need to compute any gradient, jusr return
if (!dx && !dscale && !dbias) {
return;
}
// The rank of mean should be equal to x, required by Ascend.
std::vector<int> new_shape;
for (auto i = 0; i < begin_norm_axis; ++i) {
new_shape.push_back(x_dims[i]);
}
for (auto i = begin_norm_axis; i < x_dims.size(); ++i) {
new_shape.push_back(1);
}
auto mean_dims = mean->dims();
const_cast<Tensor*>(mean)->Resize(framework::make_ddim({new_shape}));
const_cast<Tensor*>(variance)->Resize(framework::make_ddim({new_shape}));
Tensor default_scale(x->type());
if (!scale) {
default_scale.mutable_data<T>(framework::make_ddim(axes), place);
Tensor value(x->type());
value.mutable_data<T>({1}, place);
TensorFromVector(std::vector<T>{static_cast<T>(1.0)},
ctx.device_context(), &value);
auto runner =
NpuOpRunner("FillD", {value}, {default_scale}, {{"dims", axes}});
runner.Run(stream);
scale = &default_scale;
} else {
const_cast<Tensor*>(scale)->Resize(framework::make_ddim(axes));
}
Tensor dx_(dy->type()), dscale_(dy->type()), dbias_(dy->type());
dx = (dx == nullptr) ? &dx_ : dx;
dscale = (dscale == nullptr) ? &dscale_ : dscale;
dbias = (dbias == nullptr) ? &dbias_ : dbias;
dscale->Resize(framework::make_ddim(axes));
dscale->mutable_data<T>(ctx.GetPlace());
dbias->Resize(framework::make_ddim(axes));
dbias->mutable_data<T>(ctx.GetPlace());
dx->Resize(x->dims());
dx->mutable_data<T>(ctx.GetPlace());
auto runner =
NpuOpRunner("LayerNormGrad", {*dy, *x, *variance, *mean, *scale},
{*dx, *dscale, *dbias}, {});
runner.Run(stream);
const_cast<Tensor*>(mean)->Resize(mean_dims);
const_cast<Tensor*>(variance)->Resize(mean_dims);
const_cast<Tensor*>(scale)->Resize(framework::make_ddim({right}));
dscale->Resize(framework::make_ddim({right}));
dbias->Resize(framework::make_ddim({right}));
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OP_NPU_KERNEL(layer_norm, ops::LayerNormNPUKernel<float>,
ops::LayerNormNPUKernel<plat::float16>);
REGISTER_OP_NPU_KERNEL(layer_norm_grad, ops::LayerNormGradNPUKernel<float>,
ops::LayerNormGradNPUKernel<plat::float16>);
...@@ -143,8 +143,8 @@ NpuOpRunner &NpuOpRunner::AddAttr(const std::string &name, ...@@ -143,8 +143,8 @@ NpuOpRunner &NpuOpRunner::AddAttr(const std::string &name,
data.push_back(v.data()); data.push_back(v.data());
num.push_back(v.size()); num.push_back(v.size());
} }
PADDLE_ENFORCE_NPU_SUCCESS( PADDLE_ENFORCE_NPU_SUCCESS(aclopSetAttrListListInt(
aclopSetAttrListListInt(attr_, name.c_str(), data.size(), num.data(), data.data())); attr_, name.c_str(), data.size(), num.data(), data.data()));
} else { } else {
PADDLE_THROW(platform::errors::Unimplemented( PADDLE_THROW(platform::errors::Unimplemented(
"Can not convert attribubte '%s' to convert to aclopAttr", name)); "Can not convert attribubte '%s' to convert to aclopAttr", name));
...@@ -234,8 +234,9 @@ aclTensorDesc *NpuOpRunner::CreateTensorDesc(Tensor tensor) { ...@@ -234,8 +234,9 @@ aclTensorDesc *NpuOpRunner::CreateTensorDesc(Tensor tensor) {
auto format = ConvertToNpuFormat(tensor.layout()); auto format = ConvertToNpuFormat(tensor.layout());
auto dims = framework::vectorize(tensor.dims()); auto dims = framework::vectorize(tensor.dims());
VLOG(4) << dtype << " " << dims.size() << " " << dims[0] << "," << dims[1] VLOG(4) << "dtype:" << dtype << " "
<< " " << format; << "rank:" << dims.size() << " dims:" << tensor.dims()
<< " format:" << format;
auto *desc = aclCreateTensorDesc(dtype, dims.size(), dims.data(), format); auto *desc = aclCreateTensorDesc(dtype, dims.size(), dims.data(), format);
PADDLE_ENFORCE_NOT_NULL( PADDLE_ENFORCE_NOT_NULL(
......
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import numpy as np
import unittest
import sys
sys.path.append("..")
from op_test import OpTest
from functools import reduce
from operator import mul
import paddle
import paddle.fluid as fluid
import paddle.fluid.core as core
from test_layer_norm_op import _reference_layer_norm_naive, _reference_layer_norm_grad
paddle.enable_static()
SEED = 2021
EPOCH = 100
from op_test import _set_use_system_allocator
_set_use_system_allocator(False)
@unittest.skipIf(not paddle.is_compiled_with_npu(),
"core is not compiled with NPU")
class TestLayerNormOp(unittest.TestCase):
def setUp(self):
self.use_cudnn = True
self.set_npu()
self.init_dtype()
def set_npu(self):
self.__class__.use_npu = True
self.place = paddle.NPUPlace(0)
def init_dtype(self):
self.dtype = np.float32
def __assert_close(self, tensor, np_array, msg, atol=1e-4):
self.assertTrue(np.allclose(np.array(tensor), np_array, atol=atol), msg)
def check_forward_backward(self,
shape,
begin_norm_axis,
has_scale=True,
has_bias=True,
y_grad_scale=1.0,
use_mkldnn=False):
def test_with_place(place,
shape,
begin_norm_axis,
use_mkldnn=use_mkldnn):
# attr
epsilon = 0.00001
x_shape = shape
D = reduce(mul, x_shape[begin_norm_axis:len(x_shape)], 1)
scale_shape = [D]
np.random.seed(123)
x = np.random.random_sample(x_shape).astype(np.float32)
scale = np.random.random_sample(scale_shape).astype(
np.float32) if has_scale else None
bias = np.random.random_sample(scale_shape).astype(
np.float32) if has_bias else None
y_grad = (np.random.random_sample(x_shape) *
y_grad_scale).astype(np.float32)
# reference forward & backward
y, mean, variance = _reference_layer_norm_naive(
x, scale, bias, epsilon, begin_norm_axis)
x_grad, scale_grad, bias_grad = _reference_layer_norm_grad(
x, y_grad, scale, bias, mean, variance, begin_norm_axis)
var_dict = locals()
var_dict['y@GRAD'] = y_grad
var_names = ['x', 'mean', 'variance', 'y', 'y@GRAD']
if has_scale:
var_names += ['scale']
if has_bias:
var_names += ['bias']
ground_truth = {name: var_dict[name] for name in var_names}
program = fluid.Program()
with fluid.program_guard(program):
block = program.global_block()
for name in ground_truth:
block.create_var(
name=name,
dtype='float32',
shape=ground_truth[name].shape)
inputs = {"X": block.var('x')}
fetch_list = [
'y',
'mean',
'variance',
'x@GRAD',
]
if has_scale:
inputs["Scale"] = block.var('scale')
fetch_list += ['scale@GRAD']
if has_bias:
inputs["Bias"] = block.var('bias')
fetch_list += ['bias@GRAD']
layer_norm_op = block.append_op(
type="layer_norm",
inputs=inputs,
outputs={
"Y": block.var('y'),
"Mean": block.var('mean'), # share the same memory
"Variance":
block.var('variance'), # share the same memory
},
attrs={
"epsilon": epsilon,
"begin_norm_axis": begin_norm_axis,
"use_mkldnn": use_mkldnn
})
# generate backward op_desc
grad_op_desc_list, op_grad_to_var = core.get_grad_op_desc(
layer_norm_op.desc, set(), [])
grad_op_desc = grad_op_desc_list[0]
new_op_desc = block.desc.append_op()
new_op_desc.copy_from(grad_op_desc)
for var_name in grad_op_desc.output_arg_names():
block.desc.var(var_name.encode("ascii"))
grad_op_desc.infer_var_type(block.desc)
grad_op_desc.infer_shape(block.desc)
for arg in grad_op_desc.output_arg_names():
grad_var = block.desc.find_var(arg.encode("ascii"))
grad_var.set_dtype(core.VarDesc.VarType.FP32)
program._sync_with_cpp()
exe = fluid.Executor(place)
out = exe.run(program,
feed={
name: var_dict[name]
for name in ['x', 'scale', 'bias', 'y@GRAD']
},
fetch_list=fetch_list)
self.__assert_close(y, out[0], "y")
self.__assert_close(mean, out[1], "mean")
self.__assert_close(variance, out[2], "variance", 1e-3)
self.__assert_close(x_grad, out[3], "x_grad", 1e-2)
if has_scale:
self.__assert_close(scale_grad,
out[fetch_list.index('scale@GRAD')],
"scale_grad", 1e-3)
if has_bias:
self.__assert_close(bias_grad,
out[fetch_list.index('bias@GRAD')],
"bias_grad")
test_with_place(self.place, shape, begin_norm_axis)
def test_check_forward_backward_with_scale_and_bias(self):
self.check_forward_backward(shape=[2, 3, 4, 5], begin_norm_axis=1)
self.check_forward_backward(
shape=[2, 3, 4, 5],
begin_norm_axis=1,
has_scale=False,
has_bias=True)
self.check_forward_backward(
shape=[2, 3, 4, 5],
begin_norm_axis=1,
has_scale=True,
has_bias=False)
self.check_forward_backward(
shape=[2, 3, 4, 5],
begin_norm_axis=1,
has_scale=False,
has_bias=False)
self.check_forward_backward(shape=[2, 3, 4, 5], begin_norm_axis=3)
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册