Unverified commit 4640f4be, authored by ShenLiang, committed by GitHub

[OPT] FlashAttention && ModelParallel (#51617)

* fix flash_attention

* Update mp_layers.py
Parent f82da79c
......@@ -67,10 +67,9 @@ void FlashAttnUnpaddedGradKernel(const Context& ctx,
int num_splits = 0; // 0 for an internal heuristic, which is optimal
bool zero_tensors = false;
std::vector<int64_t> seed_offset_vec;
phi::TensorToVector<int64_t>(seed_offset, ctx, &seed_offset_vec);
uint64_t seed = seed_offset_vec[0];
uint64_t offset = seed_offset_vec[1];
const int64_t* seed_offset_data = seed_offset.data<int64_t>();
uint64_t seed = static_cast<uint64_t>(seed_offset_data[0]);
uint64_t offset = static_cast<uint64_t>(seed_offset_data[1]);
int64_t seq_len_q = ((max_seqlen_q + 16 - 1) / 16) * 16;
DenseTensor dsoftmax = Empty<float>(ctx, {batch_size, num_heads, seq_len_q});
......@@ -188,12 +187,10 @@ void FlashAttnGradKernel(const Context& ctx,
float scale = 1.0f / std::sqrt(head_size);
DenseTensor q_t_s =
Reshape<T, Context>(ctx, q, {total_q, num_heads, head_size});
DenseTensor k_t_s =
Reshape<T, Context>(ctx, k, {total_k, num_heads, head_size});
DenseTensor v_t_s =
Reshape<T, Context>(ctx, v, {total_k, num_heads, head_size});
DenseTensor q_t_s, k_t_s, v_t_s;
q_t_s.ShareDataWith(q).Resize({total_q, num_heads, head_size});
k_t_s.ShareDataWith(k).Resize({total_k, num_heads, head_size});
v_t_s.ShareDataWith(v).Resize({total_k, num_heads, head_size});
DenseTensor cu_seqlens_q;
DenseTensor cu_seqlens_k;
......
......@@ -75,11 +75,14 @@ void FlashAttnUnpaddedKernel(const Context& ctx,
auto gen = ctx.GetGenerator();
uint64_t inc = batch_size * num_heads * 32;
auto seed_offset_pair = gen->IncrementOffset(inc);
uint64_t seed = seed_offset_pair.first;
uint64_t offset = seed_offset_pair.second;
std::vector<int64_t> seed_offset_vec{int64_t(seed), int64_t(offset)};
phi::TensorFromVector<int64_t>(seed_offset_vec, ctx, seed_offset);
seed_offset->Resize({2});
auto* seed_offset_data = ctx.template HostAlloc<int64_t>(seed_offset);
seed_offset_data[0] = static_cast<int64_t>(seed);
seed_offset_data[1] = static_cast<int64_t>(offset);
int64_t seq_len_q = ((max_seqlen_q + 16 - 1) / 16) * 16;
......@@ -210,12 +213,10 @@ void FlashAttnKernel(const Context& ctx,
float scale = 1.0f / std::sqrt(head_size);
DenseTensor q_t_s =
Reshape<T, Context>(ctx, q, {total_q, num_heads, head_size});
DenseTensor k_t_s =
Reshape<T, Context>(ctx, k, {total_k, num_heads, head_size});
DenseTensor v_t_s =
Reshape<T, Context>(ctx, v, {total_k, num_heads, head_size});
DenseTensor q_t_s, k_t_s, v_t_s;
q_t_s.ShareDataWith(q).Resize({total_q, num_heads, head_size});
k_t_s.ShareDataWith(k).Resize({total_k, num_heads, head_size});
v_t_s.ShareDataWith(v).Resize({total_k, num_heads, head_size});
DenseTensor cu_seqlens_q;
DenseTensor cu_seqlens_k;
......
......@@ -13,6 +13,7 @@
# limitations under the License.
import paddle
from paddle.autograd import PyLayer
from paddle.fluid import core
from paddle.nn import functional as F
......@@ -328,6 +329,17 @@ class ColumnParallelLinear(paddle.nn.Layer):
return output
class MPScale(PyLayer):
@staticmethod
def forward(ctx, x, mp_degree):
out = paddle.scale(x, 1.0 / mp_degree)
return out
@staticmethod
def backward(ctx, dout):
return dout
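A note on the MPScale layer defined above: its forward divides the input by mp_degree, while its backward returns the incoming gradient unchanged. The standalone check below is my own illustration, not part of the commit; the toy mp_degree of 4 is an assumption. It shows that applying MPScale to a bias means each rank contributes bias / mp_degree before the sum-allreduce, while the bias parameter still receives the full, unscaled gradient.

# --- standalone sketch (not part of the diff) ---
import paddle
from paddle.autograd import PyLayer

# MPScale copied from the hunk above so the check is self-contained.
class MPScale(PyLayer):
    @staticmethod
    def forward(ctx, x, mp_degree):
        return paddle.scale(x, 1.0 / mp_degree)

    @staticmethod
    def backward(ctx, dout):
        return dout  # gradient passes through unscaled

bias = paddle.ones([4])
bias.stop_gradient = False
scaled = MPScale.apply(bias, 4)   # what one of 4 ranks would add before the allreduce
scaled.sum().backward()
print(scaled.numpy())             # 0.25 everywhere: 4 ranks summed recover the full bias
print(bias.grad.numpy())          # all ones: backward did not rescale the gradient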
class RowParallelLinear(paddle.nn.Layer):
"""Linear layer with mp parallelized(row).
this class is used for splitting Linear Layer in mp group, row split the weight of the Linear layer.
......@@ -467,6 +479,7 @@ class RowParallelLinear(paddle.nn.Layer):
from paddle.incubate.nn.functional import fused_linear
self.linear = fused_linear
self.fuse_matmul_bias = fuse_matmul_bias
def forward(self, x):
if self.input_is_parallel or (not self.is_mp):
......@@ -476,16 +489,30 @@ class RowParallelLinear(paddle.nn.Layer):
input_parallel = mp_ops._c_split(x, group=self.model_parallel_group)
if self.is_mp:
output_parallel = self.linear(
input_parallel, self.weight, name=self._name
)
output_ = mp_ops._mp_allreduce(
output_parallel,
group=self.model_parallel_group,
use_calc_stream=True,
use_model_parallel=True,
)
output = output_ + self.bias if self.bias is not None else output_
if self.fuse_matmul_bias:
bias = MPScale.apply(self.bias, self.world_size)
output_parallel = self.linear(
input_parallel, self.weight, bias, name=self._name
)
output = mp_ops._mp_allreduce(
output_parallel,
group=self.model_parallel_group,
use_calc_stream=True,
use_model_parallel=True,
)
else:
output_parallel = self.linear(
input_parallel, self.weight, name=self._name
)
output_ = mp_ops._mp_allreduce(
output_parallel,
group=self.model_parallel_group,
use_calc_stream=True,
use_model_parallel=True,
)
output = (
output_ + self.bias if self.bias is not None else output_
)
else:
output = self.linear(
input_parallel, self.weight, self.bias, name=self._name
......
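To see why the fused branch above pre-scales the bias with MPScale, here is a small NumPy sketch (my own illustration; world_size, the toy shapes, and the random data are assumptions): when every rank fuses matmul + bias / world_size and the partial outputs are sum-reduced, the result equals the unfused path that adds the bias once after the allreduce.

# --- standalone sketch (not part of the diff) ---
import numpy as np

world_size = 4                                       # assumed mp degree
rng = np.random.default_rng(0)
x_shards = [rng.standard_normal((2, 3)) for _ in range(world_size)]  # per-rank input slices
w_shards = [rng.standard_normal((3, 5)) for _ in range(world_size)]  # per-rank weight slices
bias = rng.standard_normal(5)

# unfused path: sum-allreduce the partial matmuls, then add the bias once
unfused = sum(x @ w for x, w in zip(x_shards, w_shards)) + bias

# fused path: each rank computes matmul + bias / world_size, then sum-allreduce
fused = sum(x @ w + bias / world_size for x, w in zip(x_shards, w_shards))

assert np.allclose(unfused, fused)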
......@@ -46,15 +46,7 @@ def _c_identity(tensor, group=None):
class c_identity_eager(PyLayer):
@staticmethod
def forward(ctx, tensor):
return _legacy_C_ops.c_identity(
tensor,
'use_calc_stream',
True,
'ring_id',
group.id,
'use_model_parallel',
True,
)
return tensor
@staticmethod
def backward(ctx, dy):
......@@ -257,15 +249,7 @@ def _mp_allreduce(
@staticmethod
def backward(ctx, dy):
return _legacy_C_ops.c_identity(
dy,
'use_calc_stream',
True,
'ring_id',
ctx.ring_id,
'use_model_parallel',
True,
)
return dy
return mp_allreduce_eager.apply(
tensor, group, use_calc_stream, use_model_parallel
......
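The last two hunks follow the same pattern: the forward of c_identity_eager and the backward of mp_allreduce_eager were pure pass-throughs implemented via a _legacy_C_ops.c_identity call, and the change returns the tensor (or gradient) directly, saving an extra kernel launch without changing the autograd semantics. Below is a minimal sketch of the pattern; it is my own illustration rather than the Paddle implementation, and the collective that lives in the other direction is omitted to keep it self-contained.

# --- standalone sketch (not part of the diff) ---
import paddle
from paddle.autograd import PyLayer

class IdentityForward(PyLayer):
    @staticmethod
    def forward(ctx, x):
        return x   # plain Python pass-through, no identity kernel launch

    @staticmethod
    def backward(ctx, dy):
        # in the real _c_identity the backward performs the model-parallel
        # reduction (not shown in the hunk); the sketch just passes it through
        return dy

x = paddle.randn([2, 3])
x.stop_gradient = False
y = IdentityForward.apply(x)
y.sum().backward()
assert bool(paddle.all(x.grad == 1.0))   # gradient flows through unchanged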