未验证 提交 2401d48d 编写于 作者: N niuliling123 提交者: GitHub

Update the rope op according to the comments (#54985)

上级 e5d08611
...@@ -15,3 +15,15 @@ ...@@ -15,3 +15,15 @@
func : fused_dropout_add_grad func : fused_dropout_add_grad
data_type : out_grad data_type : out_grad
support_dygraph_mode : true support_dygraph_mode : true
# Backward op entry paired with the forward op fused_rotary_position_embedding.
# It receives the gradients of out_q / out_k / out_v and yields q_grad / k_grad /
# v_grad. Only the q path is mandatory: the k and v gradients (inputs and
# outputs) are listed as optional below.
- backward_op : fused_rotary_position_embedding_grad
forward: fused_rotary_position_embedding (Tensor q, Tensor k, Tensor v) -> Tensor(out_q), Tensor(out_k), Tensor(out_v)
args : (Tensor out_q_grad, Tensor out_k_grad,Tensor out_v_grad)
output : Tensor(q_grad), Tensor(k_grad), Tensor(v_grad)
optional : out_k_grad, out_v_grad, k_grad, v_grad
infer_meta :
func : FusedRopeGradInferMeta
kernel :
func : fused_rotary_position_embedding_grad
# The kernel data type is selected from out_q_grad, the only required input.
data_type : out_q_grad
support_dygraph_mode : true
...@@ -88,6 +88,18 @@ ...@@ -88,6 +88,18 @@
data_type : x data_type : x
optional : cache_kv, pre_caches, rotary_pos_emb, time_step, seq_lengths, src_mask, gather_index optional : cache_kv, pre_caches, rotary_pos_emb, time_step, seq_lengths, src_mask, gather_index
# Forward op entry for the fused rotary position embedding. Takes q, k, v and
# produces out_q, out_k, out_v; k and v (and their matching outputs) are
# optional, so the op can be invoked with q alone.
- op : fused_rotary_position_embedding
args : (Tensor q, Tensor k, Tensor v)
output : Tensor(out_q), Tensor(out_k), Tensor(out_v)
infer_meta :
func : FusedRopeInferMeta
optional : k,v, out_k, out_v
kernel :
func : fused_rotary_position_embedding
# The kernel data type is selected from q, the only required input.
data_type : q
# Gradient is provided by the backward op named here.
backward: fused_rotary_position_embedding_grad
support_dygraph_mode : true
- op : generate_sequence_xpu - op : generate_sequence_xpu
args : (Tensor x, DataType dtype) args : (Tensor x, DataType dtype)
output : Tensor output : Tensor
......
...@@ -271,17 +271,6 @@ ...@@ -271,17 +271,6 @@
kernel : kernel :
func : frobenius_norm_grad func : frobenius_norm_grad
- backward_op : fused_rope_grad
forward: fused_rope (Tensor q, Tensor k, Tensor v) -> Tensor(out_q), Tensor(out_k), Tensor(out_v)
args : (Tensor out_q_grad, Tensor out_k_grad,Tensor out_v_grad)
output : Tensor(q_grad), Tensor(k_grad), Tensor(v_grad)
optional : out_k_grad, out_v_grad, k_grad, v_grad
infer_meta :
func : FusedRopeGradInferMeta
kernel :
func : fused_rope_grad
data_type : out_q_grad
- backward_op : hardswish_grad - backward_op : hardswish_grad
forward : hardswish (Tensor x) -> Tensor(out) forward : hardswish (Tensor x) -> Tensor(out)
args : (Tensor x, Tensor out_grad) args : (Tensor x, Tensor out_grad)
......
...@@ -407,17 +407,6 @@ ...@@ -407,17 +407,6 @@
optional : skip_update, master_params optional : skip_update, master_params
inplace : (params -> params_out), (moments1 -> moments1_out), (moments2 -> moments2_out), (beta1_pows -> beta1_pows_out), (beta2_pows -> beta2_pows_out), (master_params -> master_params_out) inplace : (params -> params_out), (moments1 -> moments1_out), (moments2 -> moments2_out), (beta1_pows -> beta1_pows_out), (beta2_pows -> beta2_pows_out), (master_params -> master_params_out)
- op : fused_rope
args : (Tensor q, Tensor k, Tensor v)
output : Tensor(out_q), Tensor(out_k), Tensor(out_v)
infer_meta :
func : FusedRopeInferMeta
optional : k,v, out_k, out_v
kernel :
func : fused_rope
data_type : q
backward: fused_rope_grad
- op : gaussian - op : gaussian
args : (IntArray shape, float mean, float std, int seed, DataType dtype, Place place={}) args : (IntArray shape, float mean, float std, int seed, DataType dtype, Place place={})
output: Tensor(out) output: Tensor(out)
......
...@@ -1209,11 +1209,12 @@ void FusedRopeGradInferMeta(const MetaTensor& dout_q, ...@@ -1209,11 +1209,12 @@ void FusedRopeGradInferMeta(const MetaTensor& dout_q,
MetaTensor* dk, MetaTensor* dk,
MetaTensor* dv) { MetaTensor* dv) {
auto input_dims = dout_q.dims(); auto input_dims = dout_q.dims();
PADDLE_ENFORCE_EQ(input_dims.size(), PADDLE_ENFORCE_EQ(
input_dims.size(),
4, 4,
phi::errors::InvalidArgument( phi::errors::InvalidArgument("Input should be a 4-D tensor of format "
"Input should be a 4-D tensor of format [N, C, H, W] " "[batch_size, seq_len, num_heads, head_dim],"
"or [N, H, W, C], but got %u.", "but got %u.",
input_dims.size())); input_dims.size()));
if (dout_q) { if (dout_q) {
dq->set_dims(dout_q.dims()); dq->set_dims(dout_q.dims());
......
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/kernel_registry.h"
namespace phi {
// Declaration of the backward (gradient) kernel of the fused rotary position
// embedding op. Only the signature lives here; the definition is provided by a
// backend-specific source file.
//
// Template parameters:
//   T        - presumably the tensor element type chosen at kernel
//              registration (confirm against the PD_REGISTER_KERNEL call).
//   Context  - type of the device context passed as dev_ctx.
//
// Parameters:
//   dev_ctx  - device context, taken by const reference.
//   dout_q   - gradient flowing in for out_q; always required.
//   dout_k   - gradient flowing in for out_k; may be absent (paddle::optional).
//   dout_v   - gradient flowing in for out_v; may be absent (paddle::optional).
//   dq/dk/dv - output tensors for the gradients w.r.t. q, k and v.
//              NOTE(review): dk / dv are presumably unused or null when the
//              matching optional input is absent - confirm in the definition.
template <typename T, typename Context>
void FusedRopeGradKernel(const Context& dev_ctx,
const DenseTensor& dout_q,
const paddle::optional<DenseTensor>& dout_k,
const paddle::optional<DenseTensor>& dout_v,
DenseTensor* dq,
DenseTensor* dk,
DenseTensor* dv);
} // namespace phi
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/core/dense_tensor.h"
namespace phi {
// Declaration of the forward kernel of the fused rotary position embedding op.
// Only the signature lives here; the definition is provided by a
// backend-specific source file.
//
// NOTE(review): paddle::optional is used below while this header only includes
// dense_tensor.h - presumably it is reachable transitively; confirm.
//
// Template parameters:
//   T        - presumably the tensor element type chosen at kernel
//              registration (confirm against the PD_REGISTER_KERNEL call).
//   Context  - type of the device context passed as dev_ctx.
//
// Parameters:
//   dev_ctx  - device context, taken by const reference.
//   q        - query input tensor; always required.
//   k        - key input tensor; may be absent (paddle::optional).
//   v        - value input tensor; may be absent (paddle::optional).
//   out_q / out_k / out_v - output tensors matching q, k and v.
//              NOTE(review): out_k / out_v are presumably unused or null when
//              the matching optional input is absent - confirm in the definition.
template <typename T, typename Context>
void FusedRopeKernel(const Context& dev_ctx,
const DenseTensor& q,
const paddle::optional<DenseTensor>& k,
const paddle::optional<DenseTensor>& v,
DenseTensor* out_q,
DenseTensor* out_k,
DenseTensor* out_v);
} // namespace phi
...@@ -12,8 +12,6 @@ ...@@ -12,8 +12,6 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "paddle/phi/kernels/fused_rope_grad_kernel.h"
#include "paddle/phi/backends/gpu/gpu_context.h" #include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/backends/gpu/gpu_launch_config.h" #include "paddle/phi/backends/gpu/gpu_launch_config.h"
#include "paddle/phi/common/amp_type_traits.h" #include "paddle/phi/common/amp_type_traits.h"
...@@ -21,7 +19,7 @@ ...@@ -21,7 +19,7 @@
#include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/aligned_vector.h" #include "paddle/phi/kernels/funcs/aligned_vector.h"
namespace phi { namespace phi {
namespace fusion {
template <typename T, typename MPType, int VecSize = 2> template <typename T, typename MPType, int VecSize = 2>
__global__ void VectorizedFusedRopeGradKernel(phi::Array<const T*, 3> ins_data, __global__ void VectorizedFusedRopeGradKernel(phi::Array<const T*, 3> ins_data,
int batch_size, int batch_size,
...@@ -147,17 +145,15 @@ void FusedRopeGradKernel(const Context& dev_ctx, ...@@ -147,17 +145,15 @@ void FusedRopeGradKernel(const Context& dev_ctx,
num_inputs, num_inputs,
div_c); div_c);
} }
} // namespace fusion
} // namespace phi } // namespace phi
PD_REGISTER_KERNEL(fused_rope_grad, PD_REGISTER_KERNEL(fused_rotary_position_embedding_grad,
GPU, GPU,
ALL_LAYOUT, ALL_LAYOUT,
phi::FusedRopeGradKernel, phi::fusion::FusedRopeGradKernel,
float, float,
double, double,
phi::dtype::float16, phi::dtype::float16,
phi::dtype::bfloat16) { phi::dtype::bfloat16){};
kernel->InputAt(0).SetBackend(phi::Backend::ALL_BACKEND);
kernel->InputAt(1).SetBackend(phi::Backend::ALL_BACKEND);
kernel->InputAt(2).SetBackend(phi::Backend::ALL_BACKEND);
}
...@@ -12,8 +12,6 @@ ...@@ -12,8 +12,6 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "paddle/phi/kernels/fused_rope_kernel.h"
#include "paddle/phi/backends/gpu/gpu_context.h" #include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/backends/gpu/gpu_launch_config.h" #include "paddle/phi/backends/gpu/gpu_launch_config.h"
#include "paddle/phi/common/amp_type_traits.h" #include "paddle/phi/common/amp_type_traits.h"
...@@ -21,6 +19,7 @@ ...@@ -21,6 +19,7 @@
#include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/aligned_vector.h" #include "paddle/phi/kernels/funcs/aligned_vector.h"
namespace phi { namespace phi {
namespace fusion {
template <typename T, typename MPType, int VecSize = 2> template <typename T, typename MPType, int VecSize = 2>
__global__ void VectorizedFusedRopeKernel(phi::Array<const T*, 3> ins_data, __global__ void VectorizedFusedRopeKernel(phi::Array<const T*, 3> ins_data,
...@@ -151,17 +150,14 @@ void FusedRopeKernel(const Context& dev_ctx, ...@@ -151,17 +150,14 @@ void FusedRopeKernel(const Context& dev_ctx,
num_inputs, num_inputs,
div_c); div_c);
} }
} // namespace fusion
} // namespace phi } // namespace phi
PD_REGISTER_KERNEL(fused_rope, PD_REGISTER_KERNEL(fused_rotary_position_embedding,
GPU, GPU,
ALL_LAYOUT, ALL_LAYOUT,
phi::FusedRopeKernel, phi::fusion::FusedRopeKernel,
float, float,
double, double,
phi::dtype::float16, phi::dtype::float16,
phi::dtype::bfloat16) { phi::dtype::bfloat16){};
kernel->InputAt(0).SetBackend(phi::Backend::ALL_BACKEND);
kernel->InputAt(1).SetBackend(phi::Backend::ALL_BACKEND);
kernel->InputAt(2).SetBackend(phi::Backend::ALL_BACKEND);
}
...@@ -22,9 +22,11 @@ def fused_rotary_position_embedding(q, k, v): ...@@ -22,9 +22,11 @@ def fused_rotary_position_embedding(q, k, v):
Fused rotary position embedding. Fused rotary position embedding.
Args: Args:
q (Tensor): The input tensor. The data type is bfloat16, float16, float32 or float64. q (Tensor): The input tensor. The data type is bfloat16, float16, float32 or float64. The shape of q must be [batch_size, seq_len, num_heads, head_dim] and head_dim must be a multiple of 2.
k (optional|Tensor): The input tensor. The data type is bfloat16, float16, float32 or float64. k (optional|Tensor): The input tensor. The data type is bfloat16, float16, float32 or float64. The shape of k must be [batch_size, seq_len, num_heads, head_dim] and head_dim must be a multiple of 2.
v (optional|Tensor): The input tensor. The data type is bfloat16, float16, float32 or float64.
v (optional|Tensor): The input tensor. The data type is bfloat16, float16, float32 or float64. The shape of v must be [batch_size, seq_len, num_heads, head_dim] and head_dim must be a multiple of 2.
Returns: Returns:
out_q/out_k/out_v Tensor representing the fused rotary position embedding, has same shape and data type as `q` . out_q/out_k/out_v Tensor representing the fused rotary position embedding, has same shape and data type as `q` .
...@@ -41,7 +43,7 @@ def fused_rotary_position_embedding(q, k, v): ...@@ -41,7 +43,7 @@ def fused_rotary_position_embedding(q, k, v):
q = paddle.randn([1, 1, 4, 10], dtype='float16') q = paddle.randn([1, 1, 4, 10], dtype='float16')
k = paddle.randn([1, 1, 4, 10], dtype='float16') k = paddle.randn([1, 1, 4, 10], dtype='float16')
v = paddle.randn([1, 1, 4, 10], dtype='float16') v = paddle.randn([1, 1, 4, 10], dtype='float16')
out = fused_rotary_position_embedding(q, k, v) out_q, out_k, out_v = fused_rotary_position_embedding(q, k, v)
""" """
if in_dynamic_mode(): if in_dynamic_mode():
return _C_ops.fused_rope(q, k, v) return _C_ops.fused_rotary_position_embedding(q, k, v)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册