From 2401d48d6378acc9c68ed6d4d5bb3500902ebdd7 Mon Sep 17 00:00:00 2001
From: niuliling123 <51102941+niuliling123@users.noreply.github.com>
Date: Mon, 3 Jul 2023 10:06:39 +0800
Subject: [PATCH] Update the rope op according to the comments (#54985)

---
 paddle/phi/api/yaml/fused_backward.yaml        | 12 +++++++
 paddle/phi/api/yaml/fused_ops.yaml             | 12 +++++++
 paddle/phi/api/yaml/legacy_backward.yaml       | 11 -------
 paddle/phi/api/yaml/legacy_ops.yaml            | 11 -------
 paddle/phi/infermeta/backward.cc               | 13 ++++----
 paddle/phi/kernels/fused_rope_grad_kernel.h    | 31 -------------------
 paddle/phi/kernels/fused_rope_kernel.h         | 30 ------------------
 .../gpu/fused_rope_grad_kernel.cu              | 16 ++++------
 .../{ => fusion}/gpu/fused_rope_kernel.cu      | 14 +++------
 .../fused_rotary_position_embedding.py         | 12 ++++---
 10 files changed, 49 insertions(+), 113 deletions(-)
 delete mode 100644 paddle/phi/kernels/fused_rope_grad_kernel.h
 delete mode 100644 paddle/phi/kernels/fused_rope_kernel.h
 rename paddle/phi/kernels/{ => fusion}/gpu/fused_rope_grad_kernel.cu (93%)
 rename paddle/phi/kernels/{ => fusion}/gpu/fused_rope_kernel.cu (94%)

diff --git a/paddle/phi/api/yaml/fused_backward.yaml b/paddle/phi/api/yaml/fused_backward.yaml
index f1b460ff5b5..163663f04ad 100644
--- a/paddle/phi/api/yaml/fused_backward.yaml
+++ b/paddle/phi/api/yaml/fused_backward.yaml
@@ -15,3 +15,15 @@
     func : fused_dropout_add_grad
     data_type : out_grad
   support_dygraph_mode : true
+
+- backward_op : fused_rotary_position_embedding_grad
+  forward: fused_rotary_position_embedding (Tensor q, Tensor k, Tensor v) -> Tensor(out_q), Tensor(out_k), Tensor(out_v)
+  args : (Tensor out_q_grad, Tensor out_k_grad, Tensor out_v_grad)
+  output : Tensor(q_grad), Tensor(k_grad), Tensor(v_grad)
+  optional : out_k_grad, out_v_grad, k_grad, v_grad
+  infer_meta :
+    func : FusedRopeGradInferMeta
+  kernel :
+    func : fused_rotary_position_embedding_grad
+    data_type : out_q_grad
+  support_dygraph_mode : true
diff --git a/paddle/phi/api/yaml/fused_ops.yaml b/paddle/phi/api/yaml/fused_ops.yaml
index f9dc939bf5d..64a5d2bb00a 100644
--- a/paddle/phi/api/yaml/fused_ops.yaml
+++ b/paddle/phi/api/yaml/fused_ops.yaml
@@ -88,6 +88,18 @@
     data_type : x
   optional : cache_kv, pre_caches, rotary_pos_emb, time_step, seq_lengths, src_mask, gather_index
 
+- op : fused_rotary_position_embedding
+  args : (Tensor q, Tensor k, Tensor v)
+  output : Tensor(out_q), Tensor(out_k), Tensor(out_v)
+  infer_meta :
+    func : FusedRopeInferMeta
+  optional : k, v, out_k, out_v
+  kernel :
+    func : fused_rotary_position_embedding
+    data_type : q
+  backward: fused_rotary_position_embedding_grad
+  support_dygraph_mode : true
+
 - op : generate_sequence_xpu
   args : (Tensor x, DataType dtype)
   output : Tensor
diff --git a/paddle/phi/api/yaml/legacy_backward.yaml b/paddle/phi/api/yaml/legacy_backward.yaml
index 1767eb9aeff..b3f50900c36 100755
--- a/paddle/phi/api/yaml/legacy_backward.yaml
+++ b/paddle/phi/api/yaml/legacy_backward.yaml
@@ -271,17 +271,6 @@
   kernel :
     func : frobenius_norm_grad
 
-- backward_op : fused_rope_grad
-  forward: fused_rope (Tensor q, Tensor k, Tensor v) -> Tensor(out_q), Tensor(out_k), Tensor(out_v)
-  args : (Tensor out_q_grad, Tensor out_k_grad,Tensor out_v_grad)
-  output : Tensor(q_grad), Tensor(k_grad), Tensor(v_grad)
-  optional : out_k_grad, out_v_grad, k_grad, v_grad
-  infer_meta :
-    func : FusedRopeGradInferMeta
-  kernel :
-    func : fused_rope_grad
-    data_type : out_q_grad
-
 - backward_op : hardswish_grad
   forward : hardswish (Tensor x) -> Tensor(out)
   args : (Tensor x, Tensor out_grad)
diff --git a/paddle/phi/api/yaml/legacy_ops.yaml b/paddle/phi/api/yaml/legacy_ops.yaml
index d825514af47..7d51086456a 100755
--- a/paddle/phi/api/yaml/legacy_ops.yaml
+++ b/paddle/phi/api/yaml/legacy_ops.yaml
@@ -407,17 +407,6 @@
   optional : skip_update, master_params
   inplace : (params -> params_out), (moments1 -> moments1_out), (moments2 -> moments2_out), (beta1_pows -> beta1_pows_out), (beta2_pows -> beta2_pows_out), (master_params -> master_params_out)
 
-- op : fused_rope
-  args : (Tensor q, Tensor k, Tensor v)
-  output : Tensor(out_q), Tensor(out_k), Tensor(out_v)
-  infer_meta :
-    func : FusedRopeInferMeta
-  optional : k,v, out_k, out_v
-  kernel :
-    func : fused_rope
-    data_type : q
-  backward: fused_rope_grad
-
 - op : gaussian
   args : (IntArray shape, float mean, float std, int seed, DataType dtype, Place place={})
   output: Tensor(out)
diff --git a/paddle/phi/infermeta/backward.cc b/paddle/phi/infermeta/backward.cc
index c784e295a13..1812ead3348 100644
--- a/paddle/phi/infermeta/backward.cc
+++ b/paddle/phi/infermeta/backward.cc
@@ -1209,12 +1209,13 @@ void FusedRopeGradInferMeta(const MetaTensor& dout_q,
                             MetaTensor* dk,
                             MetaTensor* dv) {
   auto input_dims = dout_q.dims();
-  PADDLE_ENFORCE_EQ(input_dims.size(),
-                    4,
-                    phi::errors::InvalidArgument(
-                        "Input should be a 4-D tensor of format [N, C, H, W] "
-                        "or [N, H, W, C], but got %u.",
-                        input_dims.size()));
+  PADDLE_ENFORCE_EQ(
+      input_dims.size(),
+      4,
+      phi::errors::InvalidArgument("Input should be a 4-D tensor of format "
+                                   "[batch_size, seq_len, num_heads, head_dim], "
+                                   "but got %u.",
+                                   input_dims.size()));
   if (dout_q) {
     dq->set_dims(dout_q.dims());
     dq->set_dtype(dout_q.dtype());
diff --git a/paddle/phi/kernels/fused_rope_grad_kernel.h b/paddle/phi/kernels/fused_rope_grad_kernel.h
deleted file mode 100644
index 26e8ed451d6..00000000000
--- a/paddle/phi/kernels/fused_rope_grad_kernel.h
+++ /dev/null
@@ -1,31 +0,0 @@
-// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-//     http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#pragma once
-
-#include "paddle/phi/core/dense_tensor.h"
-#include "paddle/phi/core/kernel_registry.h"
-
-namespace phi {
-
-template <typename T, typename Context>
-void FusedRopeGradKernel(const Context& dev_ctx,
-                         const DenseTensor& dout_q,
-                         const paddle::optional<DenseTensor>& dout_k,
-                         const paddle::optional<DenseTensor>& dout_v,
-                         DenseTensor* dq,
-                         DenseTensor* dk,
-                         DenseTensor* dv);
-
-}  // namespace phi
diff --git a/paddle/phi/kernels/fused_rope_kernel.h b/paddle/phi/kernels/fused_rope_kernel.h
deleted file mode 100644
index cdced91dcfd..00000000000
--- a/paddle/phi/kernels/fused_rope_kernel.h
+++ /dev/null
@@ -1,30 +0,0 @@
-// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-//     http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#pragma once
-
-#include "paddle/phi/core/dense_tensor.h"
-
-namespace phi {
-
-template <typename T, typename Context>
-void FusedRopeKernel(const Context& dev_ctx,
-                     const DenseTensor& q,
-                     const paddle::optional<DenseTensor>& k,
-                     const paddle::optional<DenseTensor>& v,
-                     DenseTensor* out_q,
-                     DenseTensor* out_k,
-                     DenseTensor* out_v);
-
-}  // namespace phi
diff --git a/paddle/phi/kernels/gpu/fused_rope_grad_kernel.cu b/paddle/phi/kernels/fusion/gpu/fused_rope_grad_kernel.cu
similarity index 93%
rename from paddle/phi/kernels/gpu/fused_rope_grad_kernel.cu
rename to paddle/phi/kernels/fusion/gpu/fused_rope_grad_kernel.cu
index 59db5dbdb9a..a23877b3ab2 100644
--- a/paddle/phi/kernels/gpu/fused_rope_grad_kernel.cu
+++ b/paddle/phi/kernels/fusion/gpu/fused_rope_grad_kernel.cu
@@ -12,8 +12,6 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
-#include "paddle/phi/kernels/fused_rope_grad_kernel.h"
-
 #include "paddle/phi/backends/gpu/gpu_context.h"
 #include "paddle/phi/backends/gpu/gpu_launch_config.h"
 #include "paddle/phi/common/amp_type_traits.h"
@@ -21,7 +19,7 @@
 #include "paddle/phi/core/kernel_registry.h"
 #include "paddle/phi/kernels/funcs/aligned_vector.h"
 namespace phi {
-
+namespace fusion {
 template <typename T, typename MPType, int VecSize = 2>
 __global__ void VectorizedFusedRopeGradKernel(phi::Array<const T*, 3> ins_data,
                                               int batch_size,
@@ -147,17 +145,15 @@ void FusedRopeGradKernel(const Context& dev_ctx,
                                    num_inputs,
                                    div_c);
 }
+
+}  // namespace fusion
 }  // namespace phi
 
-PD_REGISTER_KERNEL(fused_rope_grad,
+PD_REGISTER_KERNEL(fused_rotary_position_embedding_grad,
                    GPU,
                    ALL_LAYOUT,
-                   phi::FusedRopeGradKernel,
+                   phi::fusion::FusedRopeGradKernel,
                    float,
                    double,
                    phi::dtype::float16,
-                   phi::dtype::bfloat16) {
-  kernel->InputAt(0).SetBackend(phi::Backend::ALL_BACKEND);
-  kernel->InputAt(1).SetBackend(phi::Backend::ALL_BACKEND);
-  kernel->InputAt(2).SetBackend(phi::Backend::ALL_BACKEND);
-}
+                   phi::dtype::bfloat16) {}
diff --git a/paddle/phi/kernels/gpu/fused_rope_kernel.cu b/paddle/phi/kernels/fusion/gpu/fused_rope_kernel.cu
similarity index 94%
rename from paddle/phi/kernels/gpu/fused_rope_kernel.cu
rename to paddle/phi/kernels/fusion/gpu/fused_rope_kernel.cu
index f378a211a35..b155bf3f7d3 100644
--- a/paddle/phi/kernels/gpu/fused_rope_kernel.cu
+++ b/paddle/phi/kernels/fusion/gpu/fused_rope_kernel.cu
@@ -12,8 +12,6 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
-#include "paddle/phi/kernels/fused_rope_kernel.h" - #include "paddle/phi/backends/gpu/gpu_context.h" #include "paddle/phi/backends/gpu/gpu_launch_config.h" #include "paddle/phi/common/amp_type_traits.h" @@ -21,6 +19,7 @@ #include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/kernels/funcs/aligned_vector.h" namespace phi { +namespace fusion { template __global__ void VectorizedFusedRopeKernel(phi::Array ins_data, @@ -151,17 +150,14 @@ void FusedRopeKernel(const Context& dev_ctx, num_inputs, div_c); } +} // namespace fusion } // namespace phi -PD_REGISTER_KERNEL(fused_rope, +PD_REGISTER_KERNEL(fused_rotary_position_embedding, GPU, ALL_LAYOUT, - phi::FusedRopeKernel, + phi::fusion::FusedRopeKernel, float, double, phi::dtype::float16, - phi::dtype::bfloat16) { - kernel->InputAt(0).SetBackend(phi::Backend::ALL_BACKEND); - kernel->InputAt(1).SetBackend(phi::Backend::ALL_BACKEND); - kernel->InputAt(2).SetBackend(phi::Backend::ALL_BACKEND); -} + phi::dtype::bfloat16){}; diff --git a/python/paddle/incubate/nn/functional/fused_rotary_position_embedding.py b/python/paddle/incubate/nn/functional/fused_rotary_position_embedding.py index f63b58a793c..ec0fd8fb034 100644 --- a/python/paddle/incubate/nn/functional/fused_rotary_position_embedding.py +++ b/python/paddle/incubate/nn/functional/fused_rotary_position_embedding.py @@ -22,9 +22,11 @@ def fused_rotary_position_embedding(q, k, v): Fused rotary position embedding. Args: - q (Tensor): The input tensor. The data type is bfloat16, float16, float32 or float64. - k (potional|Tensor): The input tensor. The data type is bfloat16, float16, float32 or float64. - v (potional|Tensor): The input tensor. The data type is bfloat16, float16, float32 or float64. + q (Tensor): The input tensor. The data type is bfloat16, float16, float32 or float64. The shape if q must be [batch_size, seq_len, num_heads, head_dim] and head_dim must be a multiple of 2. + k (potional|Tensor): The input tensor. The data type is bfloat16, float16, float32 or float64. The shape if k must be [batch_size, seq_len, num_heads, head_dim] and head_dim must be a multiple of 2. + + v (potional|Tensor): The input tensor. The data type is bfloat16, float16, float32 or float64. The shape if v must be [batch_size, seq_len, num_heads, head_dim] and head_dim must be a multiple of 2. + Returns: out_q/out_k/out_v Tensor representing the fused rotary position embedding, has same shape and data type as `q` . @@ -41,7 +43,7 @@ def fused_rotary_position_embedding(q, k, v): q = paddle.randn([1, 1, 4, 10], dtype='float16') k = paddle.randn([1, 1, 4, 10], dtype='float16') v = paddle.randn([1, 1, 4, 10], dtype='float16') - out = fused_rotary_position_embedding(q, k, v) + out_q, out_k, out_v = fused_rotary_position_embedding(q, k, v) """ if in_dynamic_mode(): - return _C_ops.fused_rope(q, k, v) + return _C_ops.fused_rotary_position_embedding(q, k, v) -- GitLab