Unverified commit 12547fb4, authored by MarDino, committed by GitHub

Refine FusedNorm comment (#56305)

* refine static op return val
Parent 163152aa
@@ -102,6 +102,17 @@
  optional : bias, dequant_scales, shift, smooth
  support_dygraph_mode : true

- op : fused_bias_residual_layernorm
  args : (Tensor x, Tensor bias, Tensor residual, Tensor norm_weight, Tensor norm_bias, float epsilon, float residual_alpha, int begin_norm_axis, float quant_scale, int quant_round_type, float quant_max_bound, float quant_min_bound)
  output : Tensor(out), Tensor(residual_out), Tensor(mean), Tensor(variance)
  infer_meta :
    func : FusedLayerNormInferMeta
  kernel :
    func : fused_bias_residual_layernorm
    data_type : x
  optional : bias, residual, norm_weight, norm_bias, residual_out
  support_dygraph_mode : true

- op : fused_dropout_add
  args : (Tensor x, Tensor y, Tensor seed_tensor, Scalar p, bool is_test, str mode, int seed = 0, bool fix_seed = false)
  optional : seed_tensor
...
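For orientation, here is a minimal dygraph sketch of how this op is reached from Python through paddle.incubate.nn.functional.fused_layer_norm, the API exercised by the unit tests further down. The shapes, dtypes, and epsilon value are illustrative assumptions (float32 weights with float16 activations, following the pattern in the tests), and the output indexing mirrors the tests.

import numpy as np
import paddle

# Illustrative inputs; shapes and dtypes are assumptions, not part of this change.
x = paddle.to_tensor(np.random.rand(2, 256).astype('float16'))
residual = paddle.to_tensor(np.random.rand(2, 256).astype('float16'))
bias = paddle.to_tensor(np.random.rand(256).astype('float16'))
norm_weight = paddle.to_tensor(np.random.rand(256).astype('float32'))
norm_bias = paddle.to_tensor(np.random.rand(256).astype('float32'))

# Routes to fused_bias_residual_layernorm: adds bias and residual_alpha * residual,
# then applies LayerNorm over the axes starting at begin_norm_axis.
outs = paddle.incubate.nn.functional.fused_layer_norm(
    x,
    norm_weight,
    norm_bias,
    1e-5,  # epsilon
    begin_norm_axis=1,
    bias=bias,
    residual=residual,
    residual_alpha=1.0,
)
# As indexed in the tests: outs[0] is the normalized output and
# outs[1] is residual_out, i.e. x + bias + residual_alpha * residual.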
@@ -1017,16 +1017,6 @@
  data_type : dtype
  backend : place

- op : fused_bias_residual_layernorm
  args : (Tensor x, Tensor bias, Tensor residual, Tensor norm_weight, Tensor norm_bias, float epsilon, float residual_alpha, int begin_norm_axis, float quant_scale, int quant_round_type, float quant_max_bound, float quant_min_bound)
  output : Tensor(out), Tensor(residual_out), Tensor(mean), Tensor(variance)
  infer_meta :
    func : FusedLayerNormInferMeta
  kernel :
    func : fused_bias_residual_layernorm
    data_type : x
  optional : bias, residual, norm_weight, norm_bias, residual_out

- op : gather
  args : (Tensor x, Tensor index, Scalar axis=0)
  output : Tensor(out)
...
@@ -1535,6 +1535,7 @@ void FusedLayerNormInferMeta(const MetaTensor& x,
    rows *= x.dims()[i];
  }

  if (norm_weight) {
    PADDLE_ENFORCE_EQ(normalized_dims,
                      norm_weight.dims()[0],
                      phi::errors::InvalidArgument(
@@ -1544,6 +1545,7 @@ void FusedLayerNormInferMeta(const MetaTensor& x,
                          "of Weight is [%d]",
                          normalized_dims,
                          norm_weight.dims()[0]));
  }

  auto out_dims = phi::make_ddim(x_dims_vec);
...
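The guard above means norm_weight (and norm_bias) may be omitted entirely, in which case only the bias/residual add is performed. A hedged dygraph sketch of that path, modeled on check_residual_bias_add in the tests below (shapes and dtypes are assumptions):

import numpy as np
import paddle

x = paddle.to_tensor(np.random.rand(2, 256).astype('float16'))
residual = paddle.to_tensor(np.random.rand(2, 256).astype('float16'))
bias = paddle.to_tensor(np.random.rand(256).astype('float16'))

outs = paddle.incubate.nn.functional.fused_layer_norm(
    x,
    None,  # norm_weight omitted, so the dimension check above is skipped
    None,  # norm_bias omitted
    1e-5,
    begin_norm_axis=1,
    bias=bias,
    residual=residual,
    residual_alpha=1.0,
)
# outs[0] should match x + bias + residual_alpha * residual, which is exactly
# what the test asserts against naive_residual_bias_add.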
@@ -34,7 +34,6 @@ limitations under the License.
// The following code is modified from OneFlow's implementation, changed to use a
// single-pass algorithm. Supports Int8 quant/dequant Load/Store implementations.
#include "paddle/phi/kernels/fusion/gpu/fused_layernorm_kernel.h"
#include <assert.h>
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/common/amp_type_traits.h"
...
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/core/dense_tensor.h"
namespace phi {
namespace fusion {
template <typename T, typename Context>
void FusedLayerNormKernel(const Context& dev_ctx,
const DenseTensor& x,
const paddle::optional<DenseTensor>& bias,
const paddle::optional<DenseTensor>& residual,
const paddle::optional<DenseTensor>& norm_weight,
const paddle::optional<DenseTensor>& norm_bias,
const float epsilon,
const float residual_alpha,
const int begin_norm_axis,
const float quant_scale,
const int quant_round_type,
const float quant_max_bound,
const float quant_min_bound,
DenseTensor* out,
DenseTensor* residual_out,
DenseTensor* mean,
DenseTensor* variance);
} // namespace fusion
} // namespace phi
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -122,4 +122,4 @@ def fused_layer_norm(
        },
        outputs=outputs_dict,
    )
    return (out, residual_out) if residual is not None else out
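In static-graph mode the value returned here is what callers put on the fetch list; below is a minimal sketch of driving it through an executor, modeled on the static tests that follow (variable names, shapes, and the CUDA place are illustrative assumptions):

import numpy as np
import paddle

paddle.enable_static()
with paddle.static.program_guard(paddle.static.Program()):
    x = paddle.static.data("x", shape=[2, 256], dtype='float16')
    residual = paddle.static.data("residual", shape=[2, 256], dtype='float16')
    bias = paddle.static.data("bias", shape=[256], dtype='float16')

    # With a residual input, fused_layer_norm now yields (out, residual_out).
    outs = paddle.incubate.nn.functional.fused_layer_norm(
        x,
        None,
        None,
        1e-5,
        begin_norm_axis=1,
        bias=bias,
        residual=residual,
        residual_alpha=1.0,
    )

    exe = paddle.static.Executor(paddle.CUDAPlace(0))  # assumes a CUDA build
    out_np, residual_out_np = exe.run(
        feed={
            "x": np.random.rand(2, 256).astype('float16'),
            "residual": np.random.rand(2, 256).astype('float16'),
            "bias": np.random.rand(256).astype('float16'),
        },
        fetch_list=[outs],
    )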
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
...
@@ -34,6 +34,10 @@ def quant_helper(
    )


def naive_residual_bias_add(x, residual, bias, residual_alpha):
    return x + residual_alpha * residual + bias


def naive_layer_norm(x, gamma, beta, epsilon):
    x_float = paddle.cast(x, dtype=paddle.float32)
    mean = paddle.mean(x_float, axis=-1, keepdim=True)
@@ -66,13 +70,14 @@ def naive_residual_biasadd_layer_norm(
    x, residual, bias, gamma, beta, epsilon, residual_alpha
):
    x = x + residual * residual_alpha + bias
    residual_out = x
    mean = paddle.mean(x, axis=-1, keepdim=True)
    var = paddle.var(x, axis=-1, keepdim=True)
    sqrt_var = paddle.rsqrt(var + epsilon)
    out = ((x - mean) * sqrt_var) * paddle.cast(gamma, x.dtype) + paddle.cast(
        beta, x.dtype
    )
    return out, residual_out


def naive_residual_biasadd_layer_norm_int8(
@@ -88,13 +93,13 @@ def naive_residual_biasadd_layer_norm_int8(
    quant_max_bound,
    quant_min_bound,
):
    out, residual_out = naive_residual_biasadd_layer_norm(
        x, residual, bias, gamma, beta, epsilon, residual_alpha
    )
    out = quant_helper(
        out, in_scale, quant_round_type, quant_max_bound, quant_min_bound
    )
    return out, residual_out
@unittest.skipIf(
@@ -165,6 +170,29 @@ class TestlayernormOp(unittest.TestCase):
        paddle.enable_static()
        return paddle_layernorm_out, paddle_naive_layernorm_out
def check_residual_bias_add(self, x_np, residual_np, bias_np, dtype):
paddle.disable_static()
x = paddle.to_tensor(x_np.astype(dtype))
residual = paddle.to_tensor(residual_np.astype(dtype))
bias = paddle.to_tensor(bias_np.astype(dtype))
paddle_layernorm_out = paddle.incubate.nn.functional.fused_layer_norm(
x,
None,
None,
self.epsilon,
begin_norm_axis=1,
bias=bias,
residual=residual,
residual_alpha=self.residual_alpha,
)
paddle_naive_residual_out = naive_residual_bias_add(
x, residual, bias, self.residual_alpha
)
paddle.enable_static()
return (paddle_layernorm_out, paddle_naive_residual_out)
    def check_residual_bias_layernorm(
        self, x_np, gamma_np, beta_np, residual_np, bias_np, dtype
    ):
@@ -186,11 +214,18 @@ class TestlayernormOp(unittest.TestCase):
            residual_alpha=self.residual_alpha,
        )
        (
            paddle_naive_layernorm_out,
            paddle_naive_residual_out,
        ) = naive_residual_biasadd_layer_norm(
            x, residual, bias, gamma, beta, self.epsilon, self.residual_alpha
        )
        paddle.enable_static()
        return (
            paddle_layernorm_out,
            paddle_naive_layernorm_out,
            paddle_naive_residual_out,
        )
    def check_residual_bias_layernorm_int8(
        self, x_np, gamma_np, beta_np, residual_np, bias_np, dtype
@@ -217,7 +252,10 @@ class TestlayernormOp(unittest.TestCase):
            quant_min_bound=self.quant_min_bound,
        )
        (
            paddle_naive_layernorm_out,
            paddle_naive_residual_out,
        ) = naive_residual_biasadd_layer_norm_int8(
            x,
            residual,
            bias,
@@ -231,7 +269,28 @@ class TestlayernormOp(unittest.TestCase):
            self.quant_min_bound,
        )
        paddle.enable_static()
        return (
            paddle_layernorm_out,
            paddle_naive_layernorm_out,
            paddle_naive_residual_out,
        )
def test_residual_bias_add(self):
if not paddle.is_compiled_with_cuda():
return
(
paddle_residual_bias_out,
paddle_naive_residual_bias_out,
) = self.check_residual_bias_add(
self.x_np, self.residual_np, self.bias_np, 'float16'
)
np.testing.assert_allclose(
paddle_residual_bias_out[0].numpy(),
paddle_naive_residual_bias_out.numpy(),
rtol=1e-3,
atol=1e-3,
)
    def test_layernorm_fp16(self):
        if not paddle.is_compiled_with_cuda():
@@ -266,6 +325,7 @@ class TestlayernormOp(unittest.TestCase):
        (
            paddle_layernorm,
            paddle_naive_layernorm,
            paddle_naive_residual_out,
        ) = self.check_residual_bias_layernorm(
            self.x_np,
            self.norm_weight_np,
@@ -282,12 +342,20 @@ class TestlayernormOp(unittest.TestCase):
            atol=1e-3,
        )
np.testing.assert_allclose(
paddle_layernorm[1].numpy(),
paddle_naive_residual_out.numpy(),
rtol=1e-3,
atol=1e-3,
)
    def test_residual_bias_add_layernorm_int8(self):
        if not paddle.is_compiled_with_cuda():
            return
        (
            paddle_layernorm,
            paddle_naive_layernorm,
            paddle_naive_residual_out,
        ) = self.check_residual_bias_layernorm_int8(
            self.x_np,
            self.norm_weight_np,
@@ -304,6 +372,13 @@ class TestlayernormOp(unittest.TestCase):
            atol=2,
        )
np.testing.assert_allclose(
paddle_layernorm[1].numpy(),
paddle_naive_residual_out.numpy(),
rtol=1e-3,
atol=1e-3,
)
@unittest.skipIf(
    not core.is_compiled_with_cuda(), "core is not compiled with CUDA "
@@ -367,7 +442,7 @@ class TestlayernormStaticOp(unittest.TestCase):
            },
            fetch_list=[outs],
        )
        return out_s, paddle_naive_layernorm_out
    def check_layernorm_int8(self, x_np, gamma_np, beta_np, dtype):
        paddle.disable_static()
@@ -417,7 +492,56 @@ class TestlayernormStaticOp(unittest.TestCase):
            },
            fetch_list=[outs],
        )
        return out_s, paddle_naive_layernorm_out
def check_residual_bias_add(self, x_np, residual_np, bias_np, dtype):
paddle.disable_static()
x = paddle.to_tensor(x_np.astype(dtype))
residual = paddle.to_tensor(residual_np.astype(dtype))
bias = paddle.to_tensor(bias_np.astype(dtype))
paddle_naive_residual_out = naive_residual_bias_add(
x, residual, bias, self.residual_alpha
)
paddle.enable_static()
with paddle.static.program_guard(paddle.static.Program()):
x_static = paddle.static.data(
name="x_static", shape=[self.batch, self.cols], dtype=dtype
)
residual_static = paddle.static.data(
name="residual_static",
shape=[self.batch, self.cols],
dtype=dtype,
)
bias_static = paddle.static.data(
name="bias_static", shape=[self.cols], dtype=dtype
)
outs = paddle.incubate.nn.functional.fused_layer_norm(
x_static,
None,
None,
self.epsilon,
begin_norm_axis=1,
bias=bias_static,
residual=residual_static,
residual_alpha=self.residual_alpha,
quant_scale=self.quant_scale,
quant_round_type=self.quant_round_type,
quant_max_bound=self.quant_max_bound,
quant_min_bound=self.quant_min_bound,
)
exe = fluid.Executor(self.place)
out_s = exe.run(
feed={
"x_static": x_np.astype(dtype),
"residual_static": residual_np.astype(dtype),
"bias_static": bias_np.astype(dtype),
},
fetch_list=[outs],
)
return out_s, paddle_naive_residual_out
    def check_residual_bias_layernorm(
        self, x_np, gamma_np, beta_np, residual_np, bias_np, dtype
@@ -429,7 +553,10 @@ class TestlayernormStaticOp(unittest.TestCase):
        residual = paddle.to_tensor(residual_np.astype(dtype))
        bias = paddle.to_tensor(bias_np.astype(dtype))
        (
            paddle_naive_layernorm_out,
            paddle_naive_residual_out,
        ) = naive_residual_biasadd_layer_norm(
            x, residual, bias, gamma, beta, self.epsilon, self.residual_alpha
        )
        paddle.enable_static()
@@ -474,7 +601,7 @@ class TestlayernormStaticOp(unittest.TestCase):
            },
            fetch_list=[outs],
        )
        return out_s, paddle_naive_layernorm_out, paddle_naive_residual_out
    def check_residual_bias_layernorm_int8(
        self, x_np, gamma_np, beta_np, residual_np, bias_np, dtype
@@ -486,7 +613,10 @@ class TestlayernormStaticOp(unittest.TestCase):
        residual = paddle.to_tensor(residual_np.astype(dtype))
        bias = paddle.to_tensor(bias_np.astype(dtype))
        (
            paddle_naive_layernorm_out,
            paddle_naive_residual_out,
        ) = naive_residual_biasadd_layer_norm_int8(
            x,
            residual,
            bias,
@@ -545,7 +675,7 @@ class TestlayernormStaticOp(unittest.TestCase):
            },
            fetch_list=[outs],
        )
        return out_s, paddle_naive_layernorm_out, paddle_naive_residual_out
    def test_layernorm_fp16(self):
        if not paddle.is_compiled_with_cuda():
@@ -555,7 +685,7 @@ class TestlayernormStaticOp(unittest.TestCase):
        )
        np.testing.assert_allclose(
            paddle_layernorm[0],
            paddle_naive_layernorm.numpy(),
            rtol=1e-3,
            atol=1e-3,
@@ -568,18 +698,39 @@ class TestlayernormStaticOp(unittest.TestCase):
            self.x_np, self.norm_weight_np, self.norm_bias_np, 'float16'
        )
        np.testing.assert_allclose(
            paddle_layernorm[0],
            paddle_naive_layernorm.numpy(),
            rtol=2,
            atol=2,
        )
def test_residual_bias_add(self):
if not paddle.is_compiled_with_cuda():
return
(
paddle_layernorm,
paddle_naive_residual_out,
) = self.check_residual_bias_add(
self.x_np,
self.residual_np,
self.bias_np,
'float16',
)
np.testing.assert_allclose(
paddle_layernorm[0],
paddle_naive_residual_out.numpy(),
rtol=1e-3,
atol=1e-3,
)
    def test_residual_bias_add_layernorm_fp16(self):
        if not paddle.is_compiled_with_cuda():
            return
        (
            paddle_layernorm,
            paddle_naive_layernorm,
            paddle_naive_residual_out,
        ) = self.check_residual_bias_layernorm(
            self.x_np,
            self.norm_weight_np,
@@ -590,18 +741,26 @@ class TestlayernormStaticOp(unittest.TestCase):
        )
        np.testing.assert_allclose(
            paddle_layernorm[0],
            paddle_naive_layernorm.numpy(),
            rtol=1e-3,
            atol=1e-3,
        )
np.testing.assert_allclose(
paddle_layernorm[1],
paddle_naive_residual_out.numpy(),
rtol=1e-3,
atol=1e-3,
)
    def test_residual_bias_add_layernorm_int8(self):
        if not paddle.is_compiled_with_cuda():
            return
        (
            paddle_layernorm,
            paddle_naive_layernorm,
            paddle_naive_residual_out,
        ) = self.check_residual_bias_layernorm_int8(
            self.x_np,
            self.norm_weight_np,
@@ -612,12 +771,19 @@ class TestlayernormStaticOp(unittest.TestCase):
        )
        np.testing.assert_allclose(
            paddle_layernorm[0],
            paddle_naive_layernorm.numpy(),
            rtol=2,
            atol=2,
        )
np.testing.assert_allclose(
paddle_layernorm[1],
paddle_naive_residual_out.numpy(),
rtol=2,
atol=2,
)
if __name__ == "__main__":
    unittest.main()
@@ -487,7 +487,6 @@ class TestRMSNormStaticOp(unittest.TestCase):
        paddle_rmsnorm, paddle_naive_rmsnorm = self.check_rmsnorm_int8(
            self.x_np, self.norm_weight_np, self.norm_bias_np, 'float16'
        )
        print("1111")
        np.testing.assert_allclose(
            paddle_rmsnorm,
            paddle_naive_rmsnorm.numpy(),
...