Unverified commit b2407af6 authored by pangyoki, committed by GitHub

[NPU] support mixed precision input for npu layer norm (#31847)

* support mixed precision input for npu layer norm

* fix layer_norm npu kernel
Co-authored-by: zhiqiu <chenqiuliang@baidu.com>
Parent d1a4c53e
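For context, the change targets float16 activations fed into layer_norm while Scale and Bias remain float32, the usual AMP configuration. Below is a minimal usage sketch of that scenario; it assumes a Paddle build with NPU support, and the paddle.nn.functional.layer_norm call, tensor names, and shapes are illustrative only, not taken from this PR.

import numpy as np
import paddle

# Assumption: this Paddle build was compiled with NPU support.
paddle.set_device("npu:0")

# float16 input with float32 scale/bias -- the mixed-precision combination the
# kernel now accepts by casting scale/bias to the input dtype and keeping the
# mean/variance outputs in float32.
x = paddle.to_tensor(np.random.rand(2, 3, 4).astype("float16"))
weight = paddle.to_tensor(np.ones([4], dtype="float32"))
bias = paddle.to_tensor(np.zeros([4], dtype="float32"))

y = paddle.nn.functional.layer_norm(
    x, normalized_shape=[4], weight=weight, bias=bias)
print(y.dtype)  # expected: paddle.float16, matching the input dtype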
@@ -21,10 +21,36 @@ namespace operators {
using Tensor = framework::Tensor;
using DDim = framework::DDim;
using DataLayout = framework::DataLayout;
template <typename T>
class NormDataType;
template <>
class NormDataType<platform::float16> {
public:
// The scaling param type is float for HALF and FLOAT tensors
using ScalingParamType = const float;
using BatchNormParamType = float;
};
template <>
class NormDataType<float> {
public:
using ScalingParamType = const float;
using BatchNormParamType = float;
};
template <typename T>
using NormDataType = NormDataType<T>;
template <typename T>
using LayerNormParamType = typename NormDataType<T>::BatchNormParamType;
template <typename T>
class LayerNormNPUKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
using U = LayerNormParamType<T>;
const auto begin_norm_axis = ctx.Attr<int>("begin_norm_axis");
const auto epsilon = ctx.Attr<float>("epsilon");
const auto* x = ctx.Input<Tensor>("X");
@@ -43,6 +69,7 @@ class LayerNormNPUKernel : public framework::OpKernel<T> {
for (auto i = begin_norm_axis; i < x_dims.size(); ++i) {
axes.push_back(x_dims[i]);
}
auto place = ctx.GetPlace();
auto stream =
ctx.template device_context<paddle::platform::NPUDeviceContext>()
@@ -77,16 +104,93 @@ class LayerNormNPUKernel : public framework::OpKernel<T> {
} else {
const_cast<Tensor*>(bias)->Resize(framework::make_ddim(axes));
}
// cast scale from LayerNormParamType to T if needed
Tensor cast_scale(x->type());
if (x->type() == framework::proto::VarType::FP16 &&
scale->type() == framework::proto::VarType::FP32) {
cast_scale.Resize(scale->dims());
cast_scale.mutable_data<T>(ctx.GetPlace());
auto dst_dtype = ConvertToNpuDtype(x->type());
auto runner_cast_scale =
NpuOpRunner("Cast", {*scale}, {cast_scale},
{{"dst_type", static_cast<int>(dst_dtype)}});
runner_cast_scale.Run(stream);
} else {
cast_scale.ShareDataWith(*scale);
}
// cast bias from LayerNormParamType to T if needed
Tensor cast_bias(x->type());
if (x->type() == framework::proto::VarType::FP16 &&
bias->type() == framework::proto::VarType::FP32) {
cast_bias.Resize(bias->dims());
cast_bias.mutable_data<T>(ctx.GetPlace());
auto dst_dtype = ConvertToNpuDtype(x->type());
auto runner_cast_bias =
NpuOpRunner("Cast", {*bias}, {cast_bias},
{{"dst_type", static_cast<int>(dst_dtype)}});
runner_cast_bias.Run(stream);
} else {
cast_bias.ShareDataWith(*bias);
}
y->mutable_data<T>(ctx.GetPlace());
// mean should be of U type
Tensor* tmp_mean = mean;
Tensor cast_mean(x->type());
if (x->type() == framework::proto::VarType::FP16 &&
(scale->type() == framework::proto::VarType::FP32 ||
bias->type() == framework::proto::VarType::FP32)) {
cast_mean.Resize(mean->dims());
cast_mean.mutable_data<T>(ctx.GetPlace());
tmp_mean = &cast_mean;
mean->mutable_data<U>(ctx.GetPlace());
} else {
mean->mutable_data<T>(ctx.GetPlace());
}
// same for variance
Tensor* tmp_variance = variance;
Tensor cast_variance(x->type());
if (x->type() == framework::proto::VarType::FP16 &&
(scale->type() == framework::proto::VarType::FP32 ||
bias->type() == framework::proto::VarType::FP32)) {
cast_variance.Resize(variance->dims());
cast_variance.mutable_data<T>(ctx.GetPlace());
tmp_variance = &cast_variance;
variance->mutable_data<U>(ctx.GetPlace());
} else {
variance->mutable_data<T>(ctx.GetPlace());
}
auto runner = NpuOpRunner("LayerNorm", {*x, cast_scale, cast_bias},
{*y, *tmp_mean, *tmp_variance},
{{"begin_norm_axis", begin_norm_axis},
{"begin_params_axis", begin_norm_axis},
{"epsilon", epsilon}});
runner.Run(stream);
// cast back from FP16 to FP32
if (x->type() == framework::proto::VarType::FP16 &&
mean->type() == framework::proto::VarType::FP32) {
auto dst_dtype = ConvertToNpuDtype(mean->type());
auto runner_cast_mean =
NpuOpRunner("Cast", {*tmp_mean}, {*mean},
{{"dst_type", static_cast<int>(dst_dtype)}});
runner_cast_mean.Run(stream);
}
// same for variance
if (x->type() == framework::proto::VarType::FP16 &&
variance->type() == framework::proto::VarType::FP32) {
auto dst_dtype = ConvertToNpuDtype(variance->type());
auto runner_cast_variance =
NpuOpRunner("Cast", {*tmp_variance}, {*variance},
{{"dst_type", static_cast<int>(dst_dtype)}});
runner_cast_variance.Run(stream);
}
// revert shape of scale and bias
// TODO(zhiqiu): better implementation, use tmp tensor to avoid write input
// tensor.
@@ -99,6 +203,7 @@ template <typename T>
class LayerNormGradNPUKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
using U = LayerNormParamType<T>;
const auto begin_norm_axis = ctx.Attr<int>("begin_norm_axis");
const auto* x = ctx.Input<Tensor>("X");
const auto& x_dims = x->dims();
@@ -156,25 +261,115 @@ class LayerNormGradNPUKernel : public framework::OpKernel<T> {
const_cast<Tensor*>(scale)->Resize(framework::make_ddim(axes));
}
// cast scale from LayerNormParamType to T if needed
Tensor cast_scale(x->type());
if (x->type() == framework::proto::VarType::FP16 &&
scale->type() == framework::proto::VarType::FP32) {
cast_scale.Resize(scale->dims());
cast_scale.mutable_data<T>(ctx.GetPlace());
auto dst_dtype = ConvertToNpuDtype(x->type());
auto runner_cast_scale =
NpuOpRunner("Cast", {*scale}, {cast_scale},
{{"dst_type", static_cast<int>(dst_dtype)}});
runner_cast_scale.Run(stream);
} else {
cast_scale.ShareDataWith(*scale);
}
// cast mean from LayerNormParamType to T if needed
Tensor cast_mean(x->type());
if (x->type() == framework::proto::VarType::FP16 &&
mean->type() == framework::proto::VarType::FP32) {
cast_mean.Resize(mean->dims());
cast_mean.mutable_data<T>(ctx.GetPlace());
auto dst_dtype = ConvertToNpuDtype(x->type());
auto runner_cast_mean =
NpuOpRunner("Cast", {*mean}, {cast_mean},
{{"dst_type", static_cast<int>(dst_dtype)}});
runner_cast_mean.Run(stream);
} else {
cast_mean.ShareDataWith(*mean);
}
// cast variance from LayerNormParamType to T if needed
Tensor cast_variance(x->type());
if (x->type() == framework::proto::VarType::FP16 &&
variance->type() == framework::proto::VarType::FP32) {
cast_variance.Resize(variance->dims());
cast_variance.mutable_data<T>(ctx.GetPlace());
auto dst_dtype = ConvertToNpuDtype(x->type());
auto runner_cast_variance =
NpuOpRunner("Cast", {*variance}, {cast_variance},
{{"dst_type", static_cast<int>(dst_dtype)}});
runner_cast_variance.Run(stream);
} else {
cast_variance.ShareDataWith(*variance);
}
Tensor dx_(dy->type()), dscale_(dy->type()), dbias_(dy->type());
dx = (dx == nullptr) ? &dx_ : dx;
dscale = (dscale == nullptr) ? &dscale_ : dscale;
dbias = (dbias == nullptr) ? &dbias_ : dbias;
dx->Resize(x->dims());
dx->mutable_data<T>(ctx.GetPlace());
dscale->Resize(framework::make_ddim(axes));
dbias->Resize(framework::make_ddim(axes));
// dscale should be of U type
Tensor* tmp_dscale = dscale;
Tensor cast_dscale(x->type());
if (x->type() == framework::proto::VarType::FP16 &&
(mean->type() == framework::proto::VarType::FP32 ||
variance->type() == framework::proto::VarType::FP32)) {
cast_dscale.Resize(dscale->dims());
cast_dscale.mutable_data<T>(ctx.GetPlace());
tmp_dscale = &cast_dscale;
dscale->mutable_data<U>(ctx.GetPlace());
} else {
dscale->mutable_data<T>(ctx.GetPlace());
}
// same for dbias
Tensor* tmp_dbias = dbias;
Tensor cast_dbias(x->type());
if (x->type() == framework::proto::VarType::FP16 &&
(mean->type() == framework::proto::VarType::FP32 ||
variance->type() == framework::proto::VarType::FP32)) {
cast_dbias.Resize(dbias->dims());
cast_dbias.mutable_data<T>(ctx.GetPlace());
tmp_dbias = &cast_dbias;
dbias->mutable_data<U>(ctx.GetPlace());
} else {
dbias->mutable_data<T>(ctx.GetPlace());
}
auto runner = NpuOpRunner("LayerNormGrad",
{*dy, *x, cast_variance, cast_mean, cast_scale},
{*dx, *tmp_dscale, *tmp_dbias}, {});
runner.Run(stream);
// cast back from FP16 to FP32
if (x->type() == framework::proto::VarType::FP16 &&
dscale->type() == framework::proto::VarType::FP32) {
auto dst_dtype = ConvertToNpuDtype(dscale->type());
auto runner_cast_dscale =
NpuOpRunner("Cast", {*tmp_dscale}, {*dscale},
{{"dst_type", static_cast<int>(dst_dtype)}});
runner_cast_dscale.Run(stream);
}
// same for dbias
if (x->type() == framework::proto::VarType::FP16 &&
dbias->type() == framework::proto::VarType::FP32) {
auto dst_dtype = ConvertToNpuDtype(dbias->type());
auto runner_cast_dbias =
NpuOpRunner("Cast", {*tmp_dbias}, {*dbias},
{{"dst_type", static_cast<int>(dst_dtype)}});
runner_cast_dbias.Run(stream);
}
const_cast<Tensor*>(mean)->Resize(mean_dims);
const_cast<Tensor*>(variance)->Resize(mean_dims);
const_cast<Tensor*>(scale)->Resize(framework::make_ddim({right}));
...
@@ -50,9 +50,13 @@ class TestLayerNormOp(unittest.TestCase):
def init_dtype(self):
self.dtype = np.float32
self.atol = 1e-4
def __assert_close(self, tensor, np_array, msg, atol=1e-4):
self.assertTrue(
np.allclose(
np.array(tensor).astype(np_array.dtype), np_array, atol=atol),
msg)
def check_forward_backward(self,
shape,
@@ -72,13 +76,13 @@ class TestLayerNormOp(unittest.TestCase):
scale_shape = [D]
np.random.seed(123)
x = np.random.random_sample(x_shape).astype(self.dtype)
scale = np.random.random_sample(scale_shape).astype(
np.float32) if has_scale else None
bias = np.random.random_sample(scale_shape).astype(
np.float32) if has_bias else None
y_grad = (np.random.random_sample(x_shape) *
y_grad_scale).astype(self.dtype)
# reference forward & backward
y, mean, variance = _reference_layer_norm_naive(
@@ -101,7 +105,7 @@ class TestLayerNormOp(unittest.TestCase):
for name in ground_truth:
block.create_var(
name=name,
dtype=self.dtype,
shape=ground_truth[name].shape)
inputs = {"X": block.var('x')}
fetch_list = [
@@ -152,18 +156,18 @@ class TestLayerNormOp(unittest.TestCase):
for name in ['x', 'scale', 'bias', 'y@GRAD']
},
fetch_list=fetch_list)
self.__assert_close(y, out[0], "y", self.atol)
self.__assert_close(mean, out[1], "mean")
self.__assert_close(variance, out[2], "variance", 1e-3)
self.__assert_close(x_grad, out[3], "x_grad", 1e-2)
if has_scale:
self.__assert_close(scale_grad,
out[fetch_list.index('scale@GRAD')],
"scale_grad", 1e-2)
if has_bias:
self.__assert_close(bias_grad,
out[fetch_list.index('bias@GRAD')],
"bias_grad", self.atol)
test_with_place(self.place, shape, begin_norm_axis)
@@ -187,5 +191,13 @@ class TestLayerNormOp(unittest.TestCase):
self.check_forward_backward(shape=[2, 3, 4, 5], begin_norm_axis=3)
@unittest.skipIf(not paddle.is_compiled_with_npu(),
"core is not compiled with NPU")
class TestLayerNormOpFP16(TestLayerNormOp):
def init_dtype(self):
self.dtype = np.float16
self.atol = 1e-2
if __name__ == '__main__':
unittest.main()
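For reference, here is a minimal numpy layer_norm that mirrors the mixed-precision behavior the new TestLayerNormOpFP16 case checks (float16 x and y_grad, float32 scale/bias, float32 statistics). This is only a sketch under those assumptions; the test's own _reference_layer_norm_naive helper is defined elsewhere and is not shown in this diff, and layer_norm_ref is a hypothetical name.

import numpy as np

def layer_norm_ref(x, scale, bias, epsilon=1e-5, begin_norm_axis=1):
    # Flatten to [left, right] and compute statistics in float32, as the NPU
    # kernel does for float16 inputs.
    shape = x.shape
    left = int(np.prod(shape[:begin_norm_axis]))
    right = int(np.prod(shape[begin_norm_axis:]))
    x2 = x.reshape(left, right).astype(np.float32)
    mean = x2.mean(axis=1, keepdims=True)
    var = x2.var(axis=1, keepdims=True)
    y = (x2 - mean) / np.sqrt(var + epsilon)
    y = y * scale.reshape(1, right) + bias.reshape(1, right)
    # y follows the input dtype; mean/variance stay float32.
    return (y.reshape(shape).astype(x.dtype), mean.reshape(left),
            var.reshape(left))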