Unverified commit 2401d48d authored by niuliling123, committed by GitHub

Update the rope op according to the comments (#54985)

Parent e5d08611
......@@ -15,3 +15,15 @@
func : fused_dropout_add_grad
data_type : out_grad
support_dygraph_mode : true
- backward_op : fused_rotary_position_embedding_grad
forward: fused_rotary_position_embedding (Tensor q, Tensor k, Tensor v) -> Tensor(out_q), Tensor(out_k), Tensor(out_v)
args : (Tensor out_q_grad, Tensor out_k_grad, Tensor out_v_grad)
output : Tensor(q_grad), Tensor(k_grad), Tensor(v_grad)
optional : out_k_grad, out_v_grad, k_grad, v_grad
infer_meta :
func : FusedRopeGradInferMeta
kernel :
func : fused_rotary_position_embedding_grad
data_type : out_q_grad
support_dygraph_mode : true
......@@ -88,6 +88,18 @@
data_type : x
optional : cache_kv, pre_caches, rotary_pos_emb, time_step, seq_lengths, src_mask, gather_index
- op : fused_rotary_position_embedding
args : (Tensor q, Tensor k, Tensor v)
output : Tensor(out_q), Tensor(out_k), Tensor(out_v)
infer_meta :
func : FusedRopeInferMeta
optional : k, v, out_k, out_v
kernel :
func : fused_rotary_position_embedding
data_type : q
backward: fused_rotary_position_embedding_grad
support_dygraph_mode : true
- op : generate_sequence_xpu
args : (Tensor x, DataType dtype)
output : Tensor
......
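For context, the `fused_rotary_position_embedding` entry above declares `k`, `v` (and the matching `out_k`, `out_v`) as optional, so only `q` is strictly required. A minimal sketch of what that implies at the Python layer follows; passing `None` for the optional inputs and the import path are assumptions for illustration, since the diff does not show how omitted tensors are handled.

```python
import paddle
from paddle.incubate.nn.functional import fused_rotary_position_embedding

# Shapes follow [batch_size, seq_len, num_heads, head_dim]; head_dim must be even.
q = paddle.randn([1, 1, 4, 10], dtype='float16')

# Assumption: None is accepted for the optional k/v inputs declared in the YAML.
out_q, out_k, out_v = fused_rotary_position_embedding(q, None, None)
# out_k and out_v would then also be None, mirroring the optional outputs.
```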
......@@ -271,17 +271,6 @@
kernel :
func : frobenius_norm_grad
- backward_op : fused_rope_grad
forward: fused_rope (Tensor q, Tensor k, Tensor v) -> Tensor(out_q), Tensor(out_k), Tensor(out_v)
args : (Tensor out_q_grad, Tensor out_k_grad,Tensor out_v_grad)
output : Tensor(q_grad), Tensor(k_grad), Tensor(v_grad)
optional : out_k_grad, out_v_grad, k_grad, v_grad
infer_meta :
func : FusedRopeGradInferMeta
kernel :
func : fused_rope_grad
data_type : out_q_grad
- backward_op : hardswish_grad
forward : hardswish (Tensor x) -> Tensor(out)
args : (Tensor x, Tensor out_grad)
......
......@@ -407,17 +407,6 @@
optional : skip_update, master_params
inplace : (params -> params_out), (moments1 -> moments1_out), (moments2 -> moments2_out), (beta1_pows -> beta1_pows_out), (beta2_pows -> beta2_pows_out), (master_params -> master_params_out)
- op : fused_rope
args : (Tensor q, Tensor k, Tensor v)
output : Tensor(out_q), Tensor(out_k), Tensor(out_v)
infer_meta :
func : FusedRopeInferMeta
optional : k,v, out_k, out_v
kernel :
func : fused_rope
data_type : q
backward: fused_rope_grad
- op : gaussian
args : (IntArray shape, float mean, float std, int seed, DataType dtype, Place place={})
output: Tensor(out)
......
......@@ -1209,12 +1209,13 @@ void FusedRopeGradInferMeta(const MetaTensor& dout_q,
MetaTensor* dk,
MetaTensor* dv) {
auto input_dims = dout_q.dims();
PADDLE_ENFORCE_EQ(input_dims.size(),
4,
phi::errors::InvalidArgument(
"Input should be a 4-D tensor of format [N, C, H, W] "
"or [N, H, W, C], but got %u.",
input_dims.size()));
PADDLE_ENFORCE_EQ(
input_dims.size(),
4,
phi::errors::InvalidArgument("Input should be a 4-D tensor of format "
"[batch_size, seq_len, num_heads, head_dim],"
"but got %u.",
input_dims.size()));
if (dout_q) {
dq->set_dims(dout_q.dims());
dq->set_dtype(dout_q.dtype());
......
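The rewritten check above only changes the wording of the error message: the gradient kernel still requires a 4-D input laid out as [batch_size, seq_len, num_heads, head_dim]. A small illustrative sketch (sizes chosen arbitrarily) of what does and does not satisfy that check:

```python
import paddle

batch_size, seq_len, num_heads, head_dim = 2, 8, 4, 16  # illustrative sizes

# Valid: 4-D [batch_size, seq_len, num_heads, head_dim] with an even head_dim.
q = paddle.randn([batch_size, seq_len, num_heads, head_dim], dtype='float32')

# Invalid: a 3-D tensor would fail the PADDLE_ENFORCE_EQ(input_dims.size(), 4, ...)
# check above and raise InvalidArgument.
q_bad = paddle.randn([batch_size, seq_len, num_heads * head_dim], dtype='float32')
```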
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/kernel_registry.h"
namespace phi {
template <typename T, typename Context>
void FusedRopeGradKernel(const Context& dev_ctx,
const DenseTensor& dout_q,
const paddle::optional<DenseTensor>& dout_k,
const paddle::optional<DenseTensor>& dout_v,
DenseTensor* dq,
DenseTensor* dk,
DenseTensor* dv);
} // namespace phi
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/core/dense_tensor.h"
namespace phi {
template <typename T, typename Context>
void FusedRopeKernel(const Context& dev_ctx,
const DenseTensor& q,
const paddle::optional<DenseTensor>& k,
const paddle::optional<DenseTensor>& v,
DenseTensor* out_q,
DenseTensor* out_k,
DenseTensor* out_v);
} // namespace phi
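The header above only declares the forward kernel; the math lives in the CUDA implementation further down. For readers unfamiliar with RoPE, here is a plain Paddle sketch of a common rotary-embedding formulation: each adjacent channel pair is rotated by a position-dependent angle derived from base 10000. It is illustrative only and not guaranteed to be numerically identical to the fused kernel, whose exact pairing and precision handling are not shown in this diff.

```python
import paddle

def rope_reference(x, base=10000.0):
    """Illustrative rotary position embedding for x of shape
    [batch_size, seq_len, num_heads, head_dim] (head_dim even)."""
    b, s, h, d = x.shape
    pos = paddle.arange(s, dtype='float32').reshape([1, s, 1, 1])
    idx = paddle.arange(0, d, 2, dtype='float32') / d          # [d/2]
    freq = paddle.full_like(idx, base) ** (-idx)                # inverse frequencies
    angle = pos * freq                                          # [1, s, 1, d/2]
    cos = paddle.cos(angle).astype(x.dtype)
    sin = paddle.sin(angle).astype(x.dtype)
    x_even, x_odd = x[..., 0::2], x[..., 1::2]
    out_even = x_even * cos - x_odd * sin                       # rotate each (even, odd) pair
    out_odd = x_even * sin + x_odd * cos
    # Interleave the rotated pairs back into the original channel order.
    return paddle.stack([out_even, out_odd], axis=-1).reshape(x.shape)
```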
......@@ -12,8 +12,6 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/fused_rope_grad_kernel.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/backends/gpu/gpu_launch_config.h"
#include "paddle/phi/common/amp_type_traits.h"
......@@ -21,7 +19,7 @@
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/aligned_vector.h"
namespace phi {
namespace fusion {
template <typename T, typename MPType, int VecSize = 2>
__global__ void VectorizedFusedRopeGradKernel(phi::Array<const T*, 3> ins_data,
int batch_size,
......@@ -147,17 +145,15 @@ void FusedRopeGradKernel(const Context& dev_ctx,
num_inputs,
div_c);
}
} // namespace fusion
} // namespace phi
PD_REGISTER_KERNEL(fused_rope_grad,
PD_REGISTER_KERNEL(fused_rotary_position_embedding_grad,
GPU,
ALL_LAYOUT,
phi::FusedRopeGradKernel,
phi::fusion::FusedRopeGradKernel,
float,
double,
phi::dtype::float16,
phi::dtype::bfloat16) {
kernel->InputAt(0).SetBackend(phi::Backend::ALL_BACKEND);
kernel->InputAt(1).SetBackend(phi::Backend::ALL_BACKEND);
kernel->InputAt(2).SetBackend(phi::Backend::ALL_BACKEND);
}
phi::dtype::bfloat16){};
......@@ -12,8 +12,6 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/fused_rope_kernel.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/backends/gpu/gpu_launch_config.h"
#include "paddle/phi/common/amp_type_traits.h"
......@@ -21,6 +19,7 @@
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/aligned_vector.h"
namespace phi {
namespace fusion {
template <typename T, typename MPType, int VecSize = 2>
__global__ void VectorizedFusedRopeKernel(phi::Array<const T*, 3> ins_data,
......@@ -151,17 +150,14 @@ void FusedRopeKernel(const Context& dev_ctx,
num_inputs,
div_c);
}
} // namespace fusion
} // namespace phi
PD_REGISTER_KERNEL(fused_rope,
PD_REGISTER_KERNEL(fused_rotary_position_embedding,
GPU,
ALL_LAYOUT,
phi::FusedRopeKernel,
phi::fusion::FusedRopeKernel,
float,
double,
phi::dtype::float16,
phi::dtype::bfloat16) {
kernel->InputAt(0).SetBackend(phi::Backend::ALL_BACKEND);
kernel->InputAt(1).SetBackend(phi::Backend::ALL_BACKEND);
kernel->InputAt(2).SetBackend(phi::Backend::ALL_BACKEND);
}
phi::dtype::bfloat16){};
......@@ -22,9 +22,11 @@ def fused_rotary_position_embedding(q, k, v):
Fused rotary position embedding.
Args:
q (Tensor): The input tensor. The data type is bfloat16, float16, float32 or float64.
k (potional|Tensor): The input tensor. The data type is bfloat16, float16, float32 or float64.
v (potional|Tensor): The input tensor. The data type is bfloat16, float16, float32 or float64.
q (Tensor): The input tensor. The data type is bfloat16, float16, float32 or float64. The shape of q must be [batch_size, seq_len, num_heads, head_dim] and head_dim must be a multiple of 2.
k (optional|Tensor): The input tensor. The data type is bfloat16, float16, float32 or float64. The shape of k must be [batch_size, seq_len, num_heads, head_dim] and head_dim must be a multiple of 2.
v (optional|Tensor): The input tensor. The data type is bfloat16, float16, float32 or float64. The shape of v must be [batch_size, seq_len, num_heads, head_dim] and head_dim must be a multiple of 2.
Returns:
out_q/out_k/out_v: Tensors representing the fused rotary position embedding, with the same shape and data type as `q`.
......@@ -41,7 +43,7 @@ def fused_rotary_position_embedding(q, k, v):
q = paddle.randn([1, 1, 4, 10], dtype='float16')
k = paddle.randn([1, 1, 4, 10], dtype='float16')
v = paddle.randn([1, 1, 4, 10], dtype='float16')
out = fused_rotary_position_embedding(q, k, v)
out_q, out_k, out_v = fused_rotary_position_embedding(q, k, v)
"""
if in_dynamic_mode():
return _C_ops.fused_rope(q, k, v)
return _C_ops.fused_rotary_position_embedding(q, k, v)
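Because this change also wires up `fused_rotary_position_embedding_grad`, a short round-trip sketch through the new public API and its backward pass is given below. The shapes and float16 dtype mirror the docstring example; the import path and the assumption of a GPU build (the only backend registered in this diff) are mine.

```python
import paddle
from paddle.incubate.nn.functional import fused_rotary_position_embedding

# Shapes follow [batch_size, seq_len, num_heads, head_dim] with an even head_dim.
q = paddle.randn([1, 1, 4, 10], dtype='float16')
k = paddle.randn([1, 1, 4, 10], dtype='float16')
v = paddle.randn([1, 1, 4, 10], dtype='float16')
for t in (q, k, v):
    t.stop_gradient = False

out_q, out_k, out_v = fused_rotary_position_embedding(q, k, v)

# Any scalar loss works for illustration; backward routes through the newly
# registered fused_rotary_position_embedding_grad kernel.
loss = (out_q + out_k + out_v).sum()
loss.backward()
print(q.grad.shape)  # same shape as q
```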