Unverified commit 4640f4be authored by ShenLiang, committed by GitHub

[OPT] FlashAttention && ModelParallel (#51617)

* fix flash_attention

* Update mp_layers.py
Parent f82da79c
@@ -67,10 +67,9 @@ void FlashAttnUnpaddedGradKernel(const Context& ctx,
   int num_splits = 0;  // 0 for an internal heuristic, which is optimal
   bool zero_tensors = false;
 
-  std::vector<int64_t> seed_offset_vec;
-  phi::TensorToVector<int64_t>(seed_offset, ctx, &seed_offset_vec);
-  uint64_t seed = seed_offset_vec[0];
-  uint64_t offset = seed_offset_vec[1];
+  const int64_t* seed_offset_data = seed_offset.data<int64_t>();
+  uint64_t seed = static_cast<uint64_t>(seed_offset_data[0]);
+  uint64_t offset = static_cast<uint64_t>(seed_offset_data[1]);
 
   int64_t seq_len_q = ((max_seqlen_q + 16 - 1) / 16) * 16;
 
   DenseTensor dsoftmax = Empty<float>(ctx, {batch_size, num_heads, seq_len_q});
@@ -188,12 +187,10 @@ void FlashAttnGradKernel(const Context& ctx,
   float scale = 1.0f / std::sqrt(head_size);
 
-  DenseTensor q_t_s =
-      Reshape<T, Context>(ctx, q, {total_q, num_heads, head_size});
-  DenseTensor k_t_s =
-      Reshape<T, Context>(ctx, k, {total_k, num_heads, head_size});
-  DenseTensor v_t_s =
-      Reshape<T, Context>(ctx, v, {total_k, num_heads, head_size});
+  DenseTensor q_t_s, k_t_s, v_t_s;
+  q_t_s.ShareDataWith(q).Resize({total_q, num_heads, head_size});
+  k_t_s.ShareDataWith(k).Resize({total_k, num_heads, head_size});
+  v_t_s.ShareDataWith(v).Resize({total_k, num_heads, head_size});
 
   DenseTensor cu_seqlens_q;
   DenseTensor cu_seqlens_k;
...
@@ -75,11 +75,14 @@ void FlashAttnUnpaddedKernel(const Context& ctx,
   auto gen = ctx.GetGenerator();
   uint64_t inc = batch_size * num_heads * 32;
   auto seed_offset_pair = gen->IncrementOffset(inc);
 
   uint64_t seed = seed_offset_pair.first;
   uint64_t offset = seed_offset_pair.second;
 
-  std::vector<int64_t> seed_offset_vec{int64_t(seed), int64_t(offset)};
-  phi::TensorFromVector<int64_t>(seed_offset_vec, ctx, seed_offset);
+  seed_offset->Resize({2});
+  auto* seed_offset_data = ctx.template HostAlloc<int64_t>(seed_offset);
+  seed_offset_data[0] = static_cast<int64_t>(seed);
+  seed_offset_data[1] = static_cast<int64_t>(offset);
 
   int64_t seq_len_q = ((max_seqlen_q + 16 - 1) / 16) * 16;
@@ -210,12 +213,10 @@ void FlashAttnKernel(const Context& ctx,
   float scale = 1.0f / std::sqrt(head_size);
 
-  DenseTensor q_t_s =
-      Reshape<T, Context>(ctx, q, {total_q, num_heads, head_size});
-  DenseTensor k_t_s =
-      Reshape<T, Context>(ctx, k, {total_k, num_heads, head_size});
-  DenseTensor v_t_s =
-      Reshape<T, Context>(ctx, v, {total_k, num_heads, head_size});
+  DenseTensor q_t_s, k_t_s, v_t_s;
+  q_t_s.ShareDataWith(q).Resize({total_q, num_heads, head_size});
+  k_t_s.ShareDataWith(k).Resize({total_k, num_heads, head_size});
+  v_t_s.ShareDataWith(v).Resize({total_k, num_heads, head_size});
 
   DenseTensor cu_seqlens_q;
   DenseTensor cu_seqlens_k;
...
@@ -13,6 +13,7 @@
 # limitations under the License.
 
 import paddle
+from paddle.autograd import PyLayer
 from paddle.fluid import core
 from paddle.nn import functional as F
@@ -328,6 +329,17 @@ class ColumnParallelLinear(paddle.nn.Layer):
         return output
 
 
+class MPScale(PyLayer):
+    @staticmethod
+    def forward(ctx, x, mp_degree):
+        out = paddle.scale(x, 1.0 / mp_degree)
+        return out
+
+    @staticmethod
+    def backward(ctx, dout):
+        return dout
+
+
 class RowParallelLinear(paddle.nn.Layer):
     """Linear layer with mp parallelized(row).
     this class is used for splitting Linear Layer in mp group, row split the weight of the Linear layer.
@@ -467,6 +479,7 @@ class RowParallelLinear(paddle.nn.Layer):
             from paddle.incubate.nn.functional import fused_linear
 
             self.linear = fused_linear
+        self.fuse_matmul_bias = fuse_matmul_bias
 
     def forward(self, x):
         if self.input_is_parallel or (not self.is_mp):
@@ -476,16 +489,30 @@ class RowParallelLinear(paddle.nn.Layer):
             input_parallel = mp_ops._c_split(x, group=self.model_parallel_group)
 
         if self.is_mp:
-            output_parallel = self.linear(
-                input_parallel, self.weight, name=self._name
-            )
-            output_ = mp_ops._mp_allreduce(
-                output_parallel,
-                group=self.model_parallel_group,
-                use_calc_stream=True,
-                use_model_parallel=True,
-            )
-            output = output_ + self.bias if self.bias is not None else output_
+            if self.fuse_matmul_bias:
+                bias = MPScale.apply(self.bias, self.world_size)
+                output_parallel = self.linear(
+                    input_parallel, self.weight, bias, name=self._name
+                )
+                output = mp_ops._mp_allreduce(
+                    output_parallel,
+                    group=self.model_parallel_group,
+                    use_calc_stream=True,
+                    use_model_parallel=True,
+                )
+            else:
+                output_parallel = self.linear(
+                    input_parallel, self.weight, name=self._name
+                )
+                output_ = mp_ops._mp_allreduce(
+                    output_parallel,
+                    group=self.model_parallel_group,
+                    use_calc_stream=True,
+                    use_model_parallel=True,
+                )
+                output = (
+                    output_ + self.bias if self.bias is not None else output_
+                )
         else:
             output = self.linear(
                 input_parallel, self.weight, self.bias, name=self._name
...
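With fuse_matmul_bias enabled, fused_linear adds the bias on every model-parallel rank before the partial results are summed by _mp_allreduce, so an unscaled bias would be counted world_size times. MPScale therefore pre-scales the bias by 1.0 / world_size in forward and passes the gradient through untouched in backward. A minimal sketch of that arithmetic, not part of the commit: it simulates the allreduce with a local sum, and mp_degree and the bias shape are made up for illustration.

import paddle

mp_degree = 4                                 # hypothetical mp world size
bias = paddle.full([8], 2.0)                  # hypothetical bias, replicated on every rank

# Each rank adds the pre-scaled bias inside fused_linear (what MPScale.forward produces).
per_rank_bias = paddle.scale(bias, 1.0 / mp_degree)

# The mp allreduce sums the per-rank partial outputs; simulate it with add_n.
reduced_bias = paddle.add_n([per_rank_bias] * mp_degree)

print(bool(paddle.allclose(reduced_bias, bias)))  # True: the bias ends up applied exactly once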
@@ -46,15 +46,7 @@ def _c_identity(tensor, group=None):
         class c_identity_eager(PyLayer):
             @staticmethod
             def forward(ctx, tensor):
-                return _legacy_C_ops.c_identity(
-                    tensor,
-                    'use_calc_stream',
-                    True,
-                    'ring_id',
-                    group.id,
-                    'use_model_parallel',
-                    True,
-                )
+                return tensor
 
             @staticmethod
             def backward(ctx, dy):
@@ -257,15 +249,7 @@ def _mp_allreduce(
             @staticmethod
             def backward(ctx, dy):
-                return _legacy_C_ops.c_identity(
-                    dy,
-                    'use_calc_stream',
-                    True,
-                    'ring_id',
-                    ctx.ring_id,
-                    'use_model_parallel',
-                    True,
-                )
+                return dy
 
         return mp_allreduce_eager.apply(
             tensor, group, use_calc_stream, use_model_parallel
...
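Both removals rest on the same observation: in eager mode the PyLayer already spells out the backward explicitly, so routing the value through the legacy c_identity op appears to add only an extra kernel launch and copy. For _c_identity the forward is now a plain pass-through while the gradient is reduced across the model-parallel group in backward (the allreduce itself is outside this excerpt, so it is assumed in the sketch); for _mp_allreduce the forward reduces and the gradient passes through. A hedged sketch of the identity-forward / allreduce-backward pattern, with hypothetical class and variable names:

import paddle.distributed as dist
from paddle.autograd import PyLayer

class IdentityFwdAllreduceBwd(PyLayer):
    """Sketch of _c_identity's eager path: no-op forward, allreduce(sum) backward."""

    @staticmethod
    def forward(ctx, tensor, group):
        ctx.group = group
        return tensor  # identity: no extra op or copy in forward

    @staticmethod
    def backward(ctx, dy):
        # Sum the gradient over the model-parallel group (requires an initialized
        # process group, e.g. via paddle.distributed.init_parallel_env()).
        dist.all_reduce(dy, op=dist.ReduceOp.SUM, group=ctx.group)
        return dy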