Unverified · Commit a215c46a, authored by niuliling123, committed by GitHub

Add fused_rope forward op (#54351)

* style

* more

* update ctest

* Update legacy_backward.yaml

* Update legacy_ops.yaml

* Update legacy_ops.yaml

* update

* update

* update for move
Parent 7c89b972
@@ -302,6 +302,17 @@
kernel :
func : frobenius_norm_grad
- backward_op : fused_rope_grad
forward: fused_rope (Tensor q, Tensor k, Tensor v) -> Tensor(out_q), Tensor(out_k), Tensor(out_v)
args : (Tensor out_q_grad, Tensor out_k_grad, Tensor out_v_grad)
output : Tensor(q_grad), Tensor(k_grad), Tensor(v_grad)
optional : out_k_grad, out_v_grad, k_grad, v_grad
infer_meta :
func : FusedRopeGradInferMeta
kernel :
func : fused_rope_grad
data_type : out_q_grad
- backward_op : hardswish_grad
forward : hardswish (Tensor x) -> Tensor(out)
args : (Tensor x, Tensor out_grad)
......
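Note that fused_rope_grad takes only the output gradients as arguments; out_k_grad/out_v_grad (and the matching k_grad/v_grad outputs) are optional, so the op also covers the q-only case. Since RoPE rotates each even/odd feature pair by an angle, its Jacobian is the transpose of the same rotation matrix, i.e. a rotation by the negative angle applied to the output gradients. A sketch of the relationship:

$$
\begin{pmatrix} q'_{2i} \\ q'_{2i+1} \end{pmatrix}
= R(\theta_{p,i}) \begin{pmatrix} q_{2i} \\ q_{2i+1} \end{pmatrix},
\qquad
R(\theta) = \begin{pmatrix} \cos\theta & -\sin\theta \\ \sin\theta & \cos\theta \end{pmatrix},
\qquad
\frac{\partial L}{\partial q} = R(\theta_{p,i})^{\top}\,\frac{\partial L}{\partial q'} = R(-\theta_{p,i})\,\frac{\partial L}{\partial q'},
$$

with \(\theta_{p,i} = p \cdot 10000^{-2i/d}\) for sequence position \(p\) and head dimension \(d\).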
@@ -423,6 +423,17 @@
optional : skip_update, master_params
inplace : (params -> params_out), (moments1 -> moments1_out), (moments2 -> moments2_out), (beta1_pows -> beta1_pows_out), (beta2_pows -> beta2_pows_out), (master_params -> master_params_out)
- op : fused_rope
args : (Tensor q, Tensor k, Tensor v)
output : Tensor(out_q), Tensor(out_k), Tensor(out_v)
infer_meta :
func : FusedRopeInferMeta
optional : k, v, out_k, out_v
kernel :
func : fused_rope
data_type : q
backward: fused_rope_grad
- op : gaussian
args : (IntArray shape, float mean, float std, int seed, DataType dtype, Place place={})
output: Tensor(out)
......
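Once code generation runs, this op becomes callable from dygraph as _C_ops.fused_rope, which is exactly what the Python wrapper later in this diff does. A minimal sketch of a direct call; passing None for the optional v input is an assumption based on the optional markers above, not something this commit's tests exercise:

import paddle
from paddle import _C_ops

# shapes follow the kernel's expected layout: [batch_size, seq_len, num_heads, head_dim]
q = paddle.randn([2, 8, 4, 16], dtype='float16')
k = paddle.randn([2, 8, 4, 16], dtype='float16')
out_q, out_k, out_v = _C_ops.fused_rope(q, k, None)  # out_v is expected to be None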
@@ -1202,4 +1202,31 @@ void IndexAddGradInferMeta(const MetaTensor& index,
}
}
void FusedRopeGradInferMeta(const MetaTensor& dout_q,
const MetaTensor& dout_k,
const MetaTensor& dout_v,
MetaTensor* dq,
MetaTensor* dk,
MetaTensor* dv) {
auto input_dims = dout_q.dims();
PADDLE_ENFORCE_EQ(input_dims.size(),
4,
phi::errors::InvalidArgument(
"Input should be a 4-D tensor of format [N, C, H, W] "
"or [N, H, W, C], but got %u.",
input_dims.size()));
if (dout_q) {
dq->set_dims(dout_q.dims());
dq->set_dtype(dout_q.dtype());
}
if (dout_k) {
dk->set_dims(dout_k.dims());
dk->set_dtype(dout_k.dtype());
}
if (dout_v) {
dv->set_dims(dout_v.dims());
dv->set_dtype(dout_v.dtype());
}
}
} // namespace phi
@@ -184,6 +184,13 @@ void FusedDropoutAddGradInferMeta(const MetaTensor& seed_offset,
MetaTensor* x_grad,
MetaTensor* y_grad);
void FusedRopeGradInferMeta(const MetaTensor& dout_q,
const MetaTensor& dout_k,
const MetaTensor& dout_v,
MetaTensor* dq,
MetaTensor* dk,
MetaTensor* dv);
void GatherNdGradInferMeta(const MetaTensor& x,
const MetaTensor& index,
const MetaTensor& out_grad,
......
@@ -3484,6 +3484,33 @@ void FusedConvInferMeta(const MetaTensor& input,
config);
}
void FusedRopeInferMeta(const MetaTensor& q,
const MetaTensor& k,
const MetaTensor& v,
MetaTensor* out_q,
MetaTensor* out_k,
MetaTensor* out_v) {
auto input_dims = q.dims();
PADDLE_ENFORCE_EQ(input_dims.size(),
4,
phi::errors::InvalidArgument(
"Input should be a 4-D tensor of format [N, C, H, W] "
"or [N, H, W, C], but got %u.",
input_dims.size()));
if (q) {
out_q->set_dims(q.dims());
out_q->set_dtype(q.dtype());
}
if (k) {
out_k->set_dims(k.dims());
out_k->set_dtype(k.dtype());
}
if (v) {
out_v->set_dims(v.dims());
out_v->set_dtype(v.dtype());
}
}
void MoeInferMeta(const MetaTensor& x,
const MetaTensor& gate,
const MetaTensor& bmm0,
......
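Both InferMeta functions only validate the rank and copy shape/dtype from each present input to the corresponding output. An illustrative Python paraphrase of the logic (not part of the codebase):

def fused_rope_infer_meta(q_shape, k_shape=None, v_shape=None):
    # the C++ code enforces a 4-D input: [batch_size, seq_len, num_heads, head_dim]
    assert len(q_shape) == 4, f"expected a 4-D input, got {len(q_shape)}-D"
    out_k = k_shape if k_shape is not None else None  # optional outputs mirror optional inputs
    out_v = v_shape if v_shape is not None else None
    return q_shape, out_k, out_v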
@@ -673,4 +673,11 @@ void MoeInferMeta(const MetaTensor& x,
const std::string& act_type,
MetaTensor* out);
void FusedRopeInferMeta(const MetaTensor& q,
const MetaTensor& k,
const MetaTensor& v,
MetaTensor* out_q,
MetaTensor* out_k,
MetaTensor* out_v);
} // namespace phi
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/kernel_registry.h"
namespace phi {
template <typename T, typename Context>
void FusedRopeGradKernel(const Context& dev_ctx,
const DenseTensor& dout_q,
const paddle::optional<DenseTensor>& dout_k,
const paddle::optional<DenseTensor>& dout_v,
DenseTensor* dq,
DenseTensor* dk,
DenseTensor* dv);
} // namespace phi
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/core/dense_tensor.h"
namespace phi {
template <typename T, typename Context>
void FusedRopeKernel(const Context& dev_ctx,
const DenseTensor& q,
const paddle::optional<DenseTensor>& k,
const paddle::optional<DenseTensor>& v,
DenseTensor* out_q,
DenseTensor* out_k,
DenseTensor* out_v);
} // namespace phi
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/fused_rope_grad_kernel.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/backends/gpu/gpu_launch_config.h"
#include "paddle/phi/common/amp_type_traits.h"
#include "paddle/phi/core/enforce.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/aligned_vector.h"
namespace phi {
template <typename T, typename MPType, int VecSize = 2>
__global__ void VectorizedFusedRopeGradKernel(phi::Array<const T*, 3> ins_data,
int batch_size,
int seq_len,
int num_heads,
int head_dim,
phi::Array<T*, 3> outs_data,
int num_inputs,
MPType div_c) {
int index = (blockIdx.x * blockDim.x + threadIdx.x) * VecSize;
int stride = gridDim.x * blockDim.x * VecSize;
int size = batch_size * seq_len * num_heads * head_dim;
MPType sin_value[VecSize];
MPType cos_value[VecSize];
MPType result[VecSize];
T store[VecSize];
using VecType = phi::AlignedVector<T, VecSize>;
constexpr int kVectorsPerThread = VecSize / 2;
for (; index < size; index += stride) {
#pragma unroll
for (int nx = 0; nx < VecSize; ++nx) {
// get sin_index and cos_index
int index_wc = (index + nx) % (seq_len * num_heads * head_dim);
int pos_seq = index_wc / (num_heads * head_dim);
MPType idx = static_cast<MPType>((index_wc % head_dim) / 2 * 2.0);
MPType indices =
static_cast<MPType>(1) /
pow(static_cast<MPType>(10000), idx * static_cast<MPType>(div_c));
MPType value = pos_seq * indices;
sin_value[nx] = sin(value);
cos_value[nx] = cos(value);
}
#pragma unroll
for (int iter = 0; iter < 3; iter++) {
if (iter > num_inputs) break;
const T* input = ins_data[iter] + index;
VecType* out = reinterpret_cast<VecType*>(outs_data[iter] + index);
#pragma unroll
for (int nx = 0; nx < kVectorsPerThread; ++nx) {
int pr_index = nx * 2;
int ls_index = pr_index + 1;
MPType p0 = static_cast<MPType>(input[pr_index]);
MPType p1 = static_cast<MPType>(input[ls_index]);
result[pr_index] = cos_value[pr_index] * p0 + sin_value[ls_index] * p1;
result[ls_index] = cos_value[ls_index] * p1 - sin_value[pr_index] * p0;
store[pr_index] = static_cast<T>(result[pr_index]);
store[ls_index] = static_cast<T>(result[ls_index]);
}
out[0] = *(reinterpret_cast<VecType*>(store));
}
}
}
template <typename T, typename Context>
void FusedRopeGradKernel(const Context& dev_ctx,
const DenseTensor& dout_q,
const paddle::optional<DenseTensor>& dout_k,
const paddle::optional<DenseTensor>& dout_v,
DenseTensor* dq,
DenseTensor* dk,
DenseTensor* dv) {
int numel = dout_q.numel();
if (numel <= 0) return;
dev_ctx.template Alloc<T>(dq);
dq->Resize(dout_q.dims());
// q, k, v share the layout [batch_size, seq_len, num_heads, head_dim]
auto batch_size = dout_q.dims()[0];
auto seq_len = dout_q.dims()[1];
auto num_heads = dout_q.dims()[2];
auto head_dim = dout_q.dims()[3];
PADDLE_ENFORCE_NE(head_dim % 2,
1,
phi::errors::InvalidArgument(
"The head_dim of input must be a multiple of 2."));
constexpr const int vec_size = 2;
auto config =
phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, numel, vec_size);
int grid = config.block_per_grid.x;
int block = config.thread_per_block.x;
auto stream = dev_ctx.stream();
phi::Array<T*, 3> outs_data;
phi::Array<const T*, 3> ins_data;
ins_data[0] = dout_q.data<T>();
outs_data[0] = dq->data<T>();
int num_inputs = 0;
if (dout_k.get_ptr()) {
dev_ctx.template Alloc<T>(dk);
dk->Resize(dout_q.dims());
outs_data[1] = dk->data<T>();
ins_data[1] = dout_k->data<T>();
num_inputs++;
}
if (dout_v.get_ptr()) {
dev_ctx.template Alloc<T>(dv);
dv->Resize(dout_q.dims());
outs_data[2] = dv->data<T>();
ins_data[2] = dout_v->data<T>();
num_inputs++;
}
using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
MPType div_c = static_cast<MPType>(1.0f / head_dim);
VectorizedFusedRopeGradKernel<T, MPType, vec_size>
<<<grid, block, 0, stream>>>(ins_data,
batch_size,
seq_len,
num_heads,
head_dim,
outs_data,
num_inputs,
div_c);
}
} // namespace phi
PD_REGISTER_KERNEL(fused_rope_grad,
GPU,
ALL_LAYOUT,
phi::FusedRopeGradKernel,
float,
double,
phi::dtype::float16,
phi::dtype::bfloat16) {
kernel->InputAt(0).SetBackend(phi::Backend::ALL_BACKEND);
kernel->InputAt(1).SetBackend(phi::Backend::ALL_BACKEND);
kernel->InputAt(2).SetBackend(phi::Backend::ALL_BACKEND);
}
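For checking the arithmetic in VectorizedFusedRopeGradKernel, a NumPy restatement of its inner loop may help; rope_grad_reference is an illustrative sketch, not part of this commit, and assumes the [batch_size, seq_len, num_heads, head_dim] layout:

import numpy as np

def rope_grad_reference(g):
    # g: gradient w.r.t. the rotated output, shape [batch, seq_len, num_heads, head_dim]
    b, s, h, d = g.shape
    pos = np.arange(s, dtype=np.float32)                  # pos_seq in the kernel
    idx = (np.arange(d) // 2 * 2).astype(np.float32)      # mirrors (index_wc % head_dim) / 2 * 2
    angle = pos[:, None] / 10000.0 ** (idx[None, :] / d)  # mirrors pow(10000, idx * div_c)
    sin = np.sin(angle)[None, :, None, :]                 # broadcast over batch and heads
    cos = np.cos(angle)[None, :, None, :]
    dx = np.empty_like(g)
    # matches result[pr_index] and result[ls_index] in the kernel
    dx[..., 0::2] = cos[..., 0::2] * g[..., 0::2] + sin[..., 1::2] * g[..., 1::2]
    dx[..., 1::2] = cos[..., 1::2] * g[..., 1::2] - sin[..., 0::2] * g[..., 0::2]
    return dx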
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/fused_rope_kernel.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/backends/gpu/gpu_launch_config.h"
#include "paddle/phi/common/amp_type_traits.h"
#include "paddle/phi/core/enforce.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/aligned_vector.h"
namespace phi {
template <typename T, typename MPType, int VecSize = 2>
__global__ void VectorizedFusedRopeKernel(phi::Array<const T*, 3> ins_data,
int batch_size,
int seq_len,
int num_heads,
int head_dim,
phi::Array<T*, 3> outs_data,
int num_inputs,
MPType div_c) {
int index = (blockIdx.x * blockDim.x + threadIdx.x) * VecSize;
int stride = gridDim.x * blockDim.x * VecSize;
int size = batch_size * seq_len * num_heads * head_dim;
MPType sin_value[VecSize];
MPType cos_value[VecSize];
MPType result[VecSize];
T store[VecSize];
using VecType = phi::AlignedVector<T, VecSize>;
constexpr int kVectorsPerThread = VecSize / 2;
for (; index < size; index += stride) {
#pragma unroll
for (int nx = 0; nx < VecSize; ++nx) {
// get sin_index and cos_index
int index_wc = (index + nx) % (seq_len * num_heads * head_dim);
int pos_seq = index_wc / (num_heads * head_dim);
MPType idx = static_cast<MPType>((index_wc % head_dim) / 2 * 2.0);
MPType indices =
static_cast<MPType>(1) /
pow(static_cast<MPType>(10000), idx * static_cast<MPType>(div_c));
MPType value = pos_seq * indices;
sin_value[nx] = sin(value);
cos_value[nx] = cos(value);
}
#pragma unroll
for (int iter = 0; iter < 3; iter++) {
if (iter > num_inputs) break;
const T* input = ins_data[iter] + index;
VecType* out = reinterpret_cast<VecType*>(outs_data[iter] + index);
#pragma unroll
for (int nx = 0; nx < kVectorsPerThread; ++nx) {
int pr_index = nx * 2;
int ls_index = pr_index + 1;
MPType p0 = static_cast<MPType>(input[pr_index]);
MPType p1 = static_cast<MPType>(input[ls_index]);
result[pr_index] = cos_value[pr_index] * p0;
result[pr_index] -= sin_value[pr_index] * p1;
result[ls_index] = sin_value[ls_index] * p0;
result[ls_index] += cos_value[ls_index] * p1;
store[pr_index] = static_cast<T>(result[pr_index]);
store[ls_index] = static_cast<T>(result[ls_index]);
}
out[0] = *(reinterpret_cast<VecType*>(store));
}
}
}
template <typename T, typename Context>
void FusedRopeKernel(const Context& dev_ctx,
const DenseTensor& q,
const paddle::optional<DenseTensor>& k,
const paddle::optional<DenseTensor>& v,
DenseTensor* out_q,
DenseTensor* out_k,
DenseTensor* out_v) {
int numel = q.numel();
if (numel <= 0) return;
dev_ctx.template Alloc<T>(out_q);
out_q->Resize(q.dims());
// q, k, v share the layout [batch_size, seq_len, num_heads, head_dim]
auto batch_size = q.dims()[0];
auto seq_len = q.dims()[1];
auto num_heads = q.dims()[2];
auto head_dim = q.dims()[3];
PADDLE_ENFORCE_NE(head_dim % 2,
1,
phi::errors::InvalidArgument(
"The head_dim of input must be a multiple of 2."));
constexpr const int vec_size = 2;
auto config =
phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, numel, vec_size);
int grid = config.block_per_grid.x;
int block = config.thread_per_block.x;
auto stream = dev_ctx.stream();
phi::Array<T*, 3> outs_data;
phi::Array<const T*, 3> ins_data;
ins_data[0] = q.data<T>();
outs_data[0] = out_q->data<T>();
int num_inputs = 0;
if (k.get_ptr()) {
dev_ctx.template Alloc<T>(out_k);
out_k->Resize(q.dims());
ins_data[1] = k->data<T>();
outs_data[1] = out_k->data<T>();
num_inputs++;
}
if (v.get_ptr()) {
dev_ctx.template Alloc<T>(out_v);
out_v->Resize(q.dims());
ins_data[2] = v->data<T>();
outs_data[2] = out_v->data<T>();
num_inputs++;
}
using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
MPType div_c = static_cast<MPType>(1.0f / head_dim);
VectorizedFusedRopeKernel<T, MPType, vec_size>
<<<grid, block, 0, stream>>>(ins_data,
batch_size,
seq_len,
num_heads,
head_dim,
outs_data,
num_inputs,
div_c);
}
} // namespace phi
PD_REGISTER_KERNEL(fused_rope,
GPU,
ALL_LAYOUT,
phi::FusedRopeKernel,
float,
double,
phi::dtype::float16,
phi::dtype::bfloat16) {
kernel->InputAt(0).SetBackend(phi::Backend::ALL_BACKEND);
kernel->InputAt(1).SetBackend(phi::Backend::ALL_BACKEND);
kernel->InputAt(2).SetBackend(phi::Backend::ALL_BACKEND);
}
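The forward kernel applies the standard RoPE rotation to each (2i, 2i+1) feature pair. A NumPy reference that mirrors the CUDA inner loop, including the index-to-frequency mapping; rope_forward_reference is an illustrative sketch, not part of this commit:

import numpy as np

def rope_forward_reference(x):
    # x: [batch, seq_len, num_heads, head_dim]; pairs (2i, 2i+1) are rotated together
    b, s, h, d = x.shape
    pos = np.arange(s, dtype=np.float32)                  # pos_seq in the kernel
    idx = (np.arange(d) // 2 * 2).astype(np.float32)      # (index_wc % head_dim) / 2 * 2
    angle = pos[:, None] / 10000.0 ** (idx[None, :] / d)  # pow(10000, idx * div_c), div_c = 1/head_dim
    sin = np.sin(angle)[None, :, None, :]                 # broadcast over batch and heads
    cos = np.cos(angle)[None, :, None, :]
    out = np.empty_like(x)
    # matches result[pr_index] and result[ls_index] in the kernel
    out[..., 0::2] = cos[..., 0::2] * x[..., 0::2] - sin[..., 0::2] * x[..., 1::2]
    out[..., 1::2] = sin[..., 1::2] * x[..., 0::2] + cos[..., 1::2] * x[..., 1::2]
    return out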
@@ -20,6 +20,7 @@ from .fused_transformer import fused_bias_dropout_residual_layer_norm
from .fused_ec_moe import fused_ec_moe
from .fused_dropout_add import fused_dropout_add
from .fused_gate_attention import fused_gate_attention
from .fused_rotary_position_embedding import fused_rotary_position_embedding
__all__ = [
@@ -31,4 +32,5 @@ __all__ = [
'fused_bias_dropout_residual_layer_norm',
'fused_ec_moe',
'fused_dropout_add',
'fused_rotary_position_embedding',
]
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from paddle import _C_ops
from paddle.framework import in_dynamic_mode
def fused_rotary_position_embedding(q, k, v):
r"""
Fused rotary position embedding.
Args:
q (Tensor): The input tensor. The data type is bfloat16, float16, float32 or float64.
k (Tensor, optional): The input tensor. The data type is bfloat16, float16, float32 or float64.
v (Tensor, optional): The input tensor. The data type is bfloat16, float16, float32 or float64.
Returns:
out_q, out_k, out_v: Tensors with rotary position embedding applied, each with the same shape and data type as `q`.
Examples:
.. code-block:: python
# required: gpu
import paddle
from paddle.incubate.nn.functional import fused_rotary_position_embedding
q = paddle.randn([1, 1, 4, 10], dtype='float16')
k = paddle.randn([1, 1, 4, 10], dtype='float16')
v = paddle.randn([1, 1, 4, 10], dtype='float16')
out = fused_rotary_position_embedding(q, k, v)
"""
if in_dynamic_mode():
return _C_ops.fused_rope(q, k, v)
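Because k and v are declared optional in the op definition, calling the wrapper with only q should also work; this is an assumption based on the YAML above and is not exercised by the unit test in this commit:

# assumption: None is accepted for the optional inputs in dygraph
out_q, out_k, out_v = fused_rotary_position_embedding(q, None, None)
# out_k and out_v are then expected to be None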
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
import paddle
from paddle.fluid import core
from paddle.incubate.nn.functional import fused_rotary_position_embedding
def deal_qkv(init_q, init_k, init_v):
perm = [0, 2, 1, 3]
q = paddle.transpose(x=init_q, perm=perm)
k = paddle.transpose(x=init_k, perm=perm)
v = paddle.transpose(x=init_v, perm=perm)
return q, k, v
def mult_qkv(value, cos_tensor, sin_tensor):
rotate_half_q = paddle.reshape(
paddle.stack([value[:, :, :, 1::2], value[:, :, :, 0::2]], axis=-1),
paddle.shape(value),
)
query = paddle.add(
paddle.multiply(value, cos_tensor),
paddle.multiply(rotate_half_q, sin_tensor),
)
return query
def paddle_fused_rotary_position_embedding(init_q, init_k, init_v):
q, k, v = deal_qkv(init_q, init_k, init_v)
pos_seq = paddle.arange(0, q.shape[2], 1, dtype="float32")
indices = paddle.arange(0, q.shape[3], 2, dtype="float32")
indices = 1 / 10000 ** (indices / q.shape[3])
sinusoid_inp = pos_seq.unsqueeze(1) * indices.unsqueeze(0)
sin_sin = np.empty((q.shape[2] * q.shape[3]), dtype=np.float32)
cos_cos = np.empty((q.shape[2] * q.shape[3]), dtype=np.float32)
numpy_array = sinusoid_inp.numpy()
iter_array = np.nditer(numpy_array)
i = 0
for value in iter_array:
sin_sin[i * 2] = -1 * np.sin(value)
cos_cos[i * 2 + 0] = np.cos(value)
sin_sin[i * 2 + 1] = np.sin(value)
cos_cos[i * 2 + 1] = np.cos(value)
i += 1
sin_tensor = paddle.reshape(
paddle.to_tensor(sin_sin, place=paddle.CPUPlace()),
[1, 1, q.shape[2], q.shape[3]],
)
cos_tensor = paddle.reshape(
paddle.to_tensor(cos_cos, place=paddle.CPUPlace()),
[1, 1, q.shape[2], q.shape[3]],
)
query = mult_qkv(q, cos_tensor, sin_tensor)
value = mult_qkv(v, cos_tensor, sin_tensor)
key = mult_qkv(k, cos_tensor, sin_tensor)
r_query, r_key, r_value = deal_qkv(query, key, value)
return r_query, r_key, r_value
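# Note: the nditer loop above fills sin_sin/cos_cos one scalar at a time. An
# equivalent vectorized construction is sketched below (build_sin_cos is a
# hypothetical helper, shown only to document the interleaved [-sin, sin] /
# [cos, cos] layout; the test does not use it):
def build_sin_cos(sinusoid_inp, seq_len, head_dim):
    sin = np.sin(sinusoid_inp.numpy())  # [seq_len, head_dim // 2]
    cos = np.cos(sinusoid_inp.numpy())
    sin_sin = np.empty((seq_len, head_dim), dtype=np.float32)
    cos_cos = np.empty((seq_len, head_dim), dtype=np.float32)
    sin_sin[:, 0::2] = -sin  # even slots hold -sin(value)
    sin_sin[:, 1::2] = sin   # odd slots hold +sin(value)
    cos_cos[:, 0::2] = cos
    cos_cos[:, 1::2] = cos
    return sin_sin.reshape(-1), cos_cos.reshape(-1)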
@unittest.skipIf(
not core.is_compiled_with_cuda(),
"core is not compiled with CUDA ",
)
class TestFusedRotaryPositionEmbedding(unittest.TestCase):
def setUp(self):
self.shape = [1, 16, 1, 16]
self.dtype = 'float32'
self.training = True
self.seed = 1203
def get_paddle_tensor(self):
tmp = paddle.randn(self.shape, self.dtype)
tmp.stop_gradient = False
return tmp
def get_forward_backward(self, rope_function, seed):
paddle.disable_static()
paddle.seed(seed)
fw = []
bw = []
tensor_q = self.get_paddle_tensor()
tensor_k = self.get_paddle_tensor()
tensor_v = self.get_paddle_tensor()
out_q, out_k, out_v = rope_function(tensor_q, tensor_k, tensor_v)
fw.append(out_q)
fw.append(out_k)
fw.append(out_v)
out_gq = paddle.randn(out_q.shape, self.dtype)
out_gk = paddle.randn(out_q.shape, self.dtype)
out_gv = paddle.randn(out_q.shape, self.dtype)
paddle.autograd.backward(
[out_q, out_k, out_v], [out_gq, out_gk, out_gv], True
)
bw.append(tensor_q)
bw.append(tensor_k)
bw.append(tensor_v)
return fw, bw
def test_fused_rotary_position_embedding(self):
p_fw, p_bw = self.get_forward_backward(
paddle_fused_rotary_position_embedding, seed=self.seed
)
f_fw, f_bw = self.get_forward_backward(
fused_rotary_position_embedding, seed=self.seed
)
for i in range(len(p_fw)):
np.testing.assert_allclose(
p_fw[i].numpy(), f_fw[i].numpy(), rtol=1e-05
)
np.testing.assert_allclose(
p_bw[i].numpy(), f_bw[i].numpy(), rtol=1e-05
)
if __name__ == '__main__':
unittest.main()