未验证 提交 97d3d6ee 编写于 作者: MarDino's avatar MarDino 提交者: GitHub

Integrate rmsnorm kernel (#54998)

* add rmsnorm kernel
* add static graph test
* fix round type
* use alignas to avoid msvc compile error
* remove redundant headerfile to avoid rocm compile error
* fix rocm compile not found cub
* Add document
上级 852d7a12
......@@ -1994,6 +1994,16 @@
data_type : x
backward : reverse_grad
- op : rms_norm
args : (Tensor x, Tensor weight, Tensor bias, float epsilon, int begin_norm_axis)
output : Tensor(out)
infer_meta :
func : RmsNormInferMeta
kernel :
func : rms_norm
data_type : x
optional : bias
- op : rmsprop_
args : (Tensor param, Tensor mean_square, Tensor grad, Tensor moment, Tensor learning_rate, Tensor mean_grad, Tensor master_param, float epsilon = 1.0e-10f, float decay = 0.9f, float momentum = 0.0f, bool centered = false, bool multi_precision = false)
output : Tensor(param_out), Tensor(moment_out), Tensor(mean_square_out), Tensor(mean_grad_out), Tensor(master_param_outs)
......
......@@ -3137,6 +3137,38 @@ void Unpool3dInferMeta(const MetaTensor& x,
}
}
void RmsNormInferMeta(const MetaTensor& x,
const MetaTensor& weight,
const MetaTensor& bias,
const float epsilon,
const int begin_norm_axis,
MetaTensor* out) {
std::vector<int64_t> x_dims_vec = phi::vectorize(x.dims());
auto x_dims_size = x_dims_vec.size();
size_t normalized_dims = 1;
for (size_t i = begin_norm_axis; i < x_dims_size; ++i) {
normalized_dims *= x_dims_vec[i];
}
PADDLE_ENFORCE_EQ(normalized_dims,
weight.dims()[0],
phi::errors::InvalidArgument(
"The normalized size of Input(X) must equal to be"
"the size of Weight, but received"
"normalized size of Input(X) is [%d], received size"
"of Weight is [%d]",
normalized_dims,
weight.dims()[0]));
auto out_dims = phi::make_ddim(x_dims_vec);
out->set_dims(out_dims);
out->set_dtype(x.dtype());
out->set_layout(x.layout());
out->share_lod(x);
}
} // namespace phi
PD_REGISTER_INFER_META_FN(add_raw, phi::ElementwiseRawInferMeta);
......@@ -479,4 +479,11 @@ void Unpool3dInferMeta(const MetaTensor& x,
MetaTensor* out,
MetaConfig config = MetaConfig());
void RmsNormInferMeta(const MetaTensor& x,
const MetaTensor& weight,
const MetaTensor& bias,
const float epsilon,
const int begin_norm_axis,
MetaTensor* out);
} // namespace phi
此差异已折叠。
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/selected_rows.h"
namespace phi {
template <typename T, typename Context>
void RmsNormKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& weight,
const paddle::optional<DenseTensor>& bias,
float epsilon,
int begin_norm_axis,
DenseTensor* out);
template <typename T, typename Context>
void RmsNormWrapper(const Context& ctx,
const T* x,
const T* weight,
const T* bias,
const float epsilon,
const int rows,
const int cols,
T* output);
template <typename T, typename Context>
void ResidualAddRmsNormWrapper(const Context& ctx,
const T* x,
const T* residual,
const T* bias,
const T* norm_weight,
const T* norm_bias,
const float epsilon,
const int rows,
const int cols,
T* residual_output,
T* output);
template <typename T, typename Context>
void RmsNormInt8OutWrapper(const Context& ctx,
const T* x,
const T* weight,
const T* bias,
const float epsilon,
const int rows,
const int cols,
const float in_scale,
const int quant_round_type,
const float quant_max_bound,
const float quant_min_bound,
int8_t* output);
template <typename T, typename Context>
void ResidualAddRmsNormInt8OutWrapper(const Context& ctx,
const T* x,
const T* residual,
const T* bias,
const T* norm_weight,
const T* norm_bias,
const float epsilon,
const int rows,
const int cols,
const float in_scale,
const int quant_round_type,
const float quant_max_bound,
const float quant_min_bound,
T* residual_output,
int8_t* output);
} // namespace phi
......@@ -21,6 +21,7 @@ from .fused_ec_moe import fused_ec_moe
from .fused_dropout_add import fused_dropout_add
from .fused_gate_attention import fused_gate_attention
from .fused_rotary_position_embedding import fused_rotary_position_embedding
from .rms_norm import rms_norm
__all__ = [
......@@ -33,4 +34,5 @@ __all__ = [
'fused_ec_moe',
'fused_dropout_add',
'fused_rotary_position_embedding',
"rms_norm",
]
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from paddle import _C_ops
from paddle.fluid.framework import in_dygraph_mode
from paddle.fluid.layer_helper import LayerHelper
def rms_norm(x, weight, bias, epsilon, begin_norm_axis):
r"""
Apply RMSNorm kernel.
Args:
x (Tensor): the input Tensor..
weight (Tensor): the weight Tensor to affine output.
bias (Tensor): the bias Tensor to affine output.
epsilon (float): a small float number to avoid divide 0.
begin_norm_axis (int): the begin axis to normalize.
Returns:
Tensor: the output Tensor.
Examples:
.. code-block:: python
# required: gpu
import paddle
paddle_x = paddle.cast(paddle.randn(shape=[32, 256]), dtype=paddle.float16)
paddle_weight = paddle.cast(paddle.randn(shape=[256]), dtype=paddle.float16)
paddle_bias = paddle.cast(paddle.randn(shape=[256]), dtype=paddle.float16)
epsilon = 1e-6
paddle_rmsnorm = paddle.incubate.nn.functional.rms_norm(paddle_x, paddle_weight, paddle_bias, epsilon, 1)
"""
if in_dygraph_mode():
return _C_ops.rms_norm(x, weight, bias, epsilon, begin_norm_axis)
helper = LayerHelper('rms_norm', **locals())
out = helper.create_variable_for_type_inference(dtype=x.dtype)
helper.append_op(
type='rms_norm',
inputs={'x': x, 'weight': weight, 'bias': bias},
attrs={"epsilon": epsilon, "begin_norm_axis": begin_norm_axis},
outputs={'out': out},
)
return out
......@@ -75,6 +75,7 @@ if(NOT WITH_GPU)
list(REMOVE_ITEM TEST_OPS test_fused_transformer_encoder_layer)
list(REMOVE_ITEM TEST_OPS test_fused_bias_dropout_residual_layer_norm_op)
list(REMOVE_ITEM TEST_OPS test_fused_bias_dropout_residual_layer_norm_op_api)
list(REMOVE_ITEM TEST_OPS test_rms_norm_op)
list(REMOVE_ITEM TEST_OPS test_fused_attention_pass)
list(REMOVE_ITEM TEST_OPS test_fused_feedforward_pass)
list(REMOVE_ITEM TEST_OPS test_fused_comm_buffer)
......@@ -154,6 +155,7 @@ if(WIN32)
list(REMOVE_ITEM TEST_OPS test_trt_convert_preln_residual_bias)
list(REMOVE_ITEM TEST_OPS test_fused_multi_transformer_int8_op)
list(REMOVE_ITEM TEST_OPS test_fused_ec_moe_op)
list(REMOVE_ITEM TEST_OPS test_rms_norm_op)
list(REMOVE_ITEM TEST_OPS test_linear_compress)
endif()
list(REMOVE_ITEM TEST_OPS test_checkpoint_saver)
......
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
import paddle
from paddle import fluid
from paddle.fluid import core
@unittest.skipIf(
not core.is_compiled_with_cuda(), "core is not compiled with CUDA "
)
class TestRMSNormOp(unittest.TestCase):
def setUp(self):
np.random.seed(20)
batch = 32
cols = 256
self.x_np = np.random.random([batch, 256])
self.gamma_np = np.random.random([256])
self.beta_np = np.random.random([256])
self.epsilon = 1e-6
def naive_rms_norm(self, x, gamma, beta):
variance = x.pow(2).mean(-1, keepdim=True)
out = paddle.rsqrt(variance + self.epsilon) * x
out = out * gamma + beta
return out
def check_main(self, x_np, gamma_np, beta_np, dtype):
paddle.disable_static()
x = paddle.to_tensor(x_np.astype(dtype))
gamma = paddle.to_tensor(gamma_np.astype(dtype))
beta = paddle.to_tensor(beta_np.astype(dtype))
paddle_rmsnorm_out = paddle.incubate.nn.functional.rms_norm(
x, gamma, beta, self.epsilon, begin_norm_axis=1
)
paddle_naive_rmsnorm_out = self.naive_rms_norm(x, gamma, beta)
paddle.enable_static()
return paddle_rmsnorm_out, paddle_naive_rmsnorm_out
def test_rmsnorm_fp16(self):
if not paddle.is_compiled_with_cuda():
return
paddle_rmsnorm, paddle_naive_rmsnorm = self.check_main(
self.x_np, self.gamma_np, self.beta_np, 'float16'
)
np.testing.assert_allclose(
paddle_rmsnorm.numpy(),
paddle_naive_rmsnorm.numpy(),
rtol=1e-03,
atol=1e-3,
)
def test_rmsnorm_fp32(self):
if not paddle.is_compiled_with_cuda():
return
paddle_rmsnorm, paddle_naive_rmsnorm = self.check_main(
self.x_np, self.gamma_np, self.beta_np, 'float32'
)
np.testing.assert_allclose(
paddle_rmsnorm.numpy(),
paddle_naive_rmsnorm.numpy(),
rtol=1e-3,
atol=1e-3,
)
@unittest.skipIf(
not core.is_compiled_with_cuda(), "core is not compiled with CUDA "
)
class TestRMSNormStaticOp(unittest.TestCase):
def setUp(self):
np.random.seed(20)
self.batch = 32
self.cols = 256
self.x_np = np.random.random([self.batch, 256])
self.gamma_np = np.random.random([256])
self.beta_np = np.random.random([256])
self.epsilon = 1e-6
self.place = paddle.CUDAPlace(0)
def naive_rms_norm(self, x, gamma, beta):
variance = x.pow(2).mean(-1, keepdim=True)
out = paddle.rsqrt(variance + self.epsilon) * x
out = out * gamma + beta
return out
def check_main(self, x_np, gamma_np, beta_np, dtype):
paddle.disable_static()
x = paddle.to_tensor(x_np.astype(dtype))
gamma = paddle.to_tensor(gamma_np.astype(dtype))
beta = paddle.to_tensor(beta_np.astype(dtype))
paddle_naive_rmsnorm_out = self.naive_rms_norm(x, gamma, beta)
paddle.enable_static()
with paddle.static.program_guard(paddle.static.Program()):
x_static = paddle.static.data(
name="x_static", shape=[self.batch, self.cols], dtype=dtype
)
gamma_static = paddle.static.data(
name="gamma_static", shape=[self.cols], dtype=dtype
)
beta_static = paddle.static.data(
name="beta_static", shape=[self.cols], dtype=dtype
)
outs = paddle.incubate.nn.functional.rms_norm(
x_static,
gamma_static,
beta_static,
self.epsilon,
begin_norm_axis=1,
)
exe = fluid.Executor(self.place)
out_s = exe.run(
feed={
"x_static": x_np.astype(dtype),
"gamma_static": gamma_np.astype(dtype),
"beta_static": beta_np.astype(dtype),
},
fetch_list=[outs],
)
return out_s[0], paddle_naive_rmsnorm_out
def test_rmsnorm_fp16(self):
if not paddle.is_compiled_with_cuda():
return
paddle_rmsnorm, paddle_naive_rmsnorm = self.check_main(
self.x_np, self.gamma_np, self.beta_np, 'float16'
)
np.testing.assert_allclose(
paddle_rmsnorm,
paddle_naive_rmsnorm.numpy(),
rtol=1e-3,
atol=1e-3,
)
def test_rmsnorm_fp32(self):
if not paddle.is_compiled_with_cuda():
return
paddle_rmsnorm, paddle_naive_rmsnorm = self.check_main(
self.x_np, self.gamma_np, self.beta_np, 'float32'
)
np.testing.assert_allclose(
paddle_rmsnorm,
paddle_naive_rmsnorm.numpy(),
rtol=1e-3,
atol=1e-3,
)
if __name__ == "__main__":
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册