Unverified commit 25a0b46d, authored by duanyanhui, committed by GitHub

optimize softmax_mask_fuse (#56877)

Parent commit d38cd6ce
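This change lets fused_softmax_mask accept a float32 Mask alongside a float16 X: the AMP pass stops casting the mask down, and the CUDA kernel gains a separate template parameter MT for the mask type. A minimal usage sketch of the newly supported mixed-dtype call, assuming a CUDA build of Paddle; shapes and mask construction follow the new unit test at the bottom of this diff:

import numpy as np
import paddle

# X in float16, additive mask in float32 (the combination this commit enables).
x = paddle.to_tensor(np.random.random((1, 1, 8, 32)).astype("float16"))
mask = np.random.randint(0, 2, (1, 1, 8, 32)).astype("float32")
mask = paddle.to_tensor(np.where(mask == 1, -10000.0, mask).astype("float32"))

out = paddle.incubate.softmax_mask_fuse(x, mask)  # fused softmax(x + mask)
print(out.dtype)  # expected to follow X: paddle.float16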
@@ -75,6 +75,10 @@ inline paddle::Tensor AmpAutoCast(const std::string& input_name,
input_name != "X") {
return input;
}
if (op_name == "fused_softmax_mask" && input_name == "Mask" &&
input.dtype() == phi::DataType::FLOAT32) {
return input;
}
if (dst_dtype == phi::DataType::FLOAT16) {
if (op_name == "run_program") {
return input;
......
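The AmpAutoCast hunk above exempts the Mask input of fused_softmax_mask from the fp16 cast when it arrives as float32, in the same style as the exemptions directly before it. A hedged sketch of the intended effect at the Python level follows; whether the op is auto-cast at all under O1 depends on the AMP op lists, so treat this as an illustration rather than guaranteed behaviour:

import paddle

x = paddle.rand([1, 1, 8, 32], dtype="float32")
mask = paddle.zeros([1, 1, 8, 32], dtype="float32")  # additive mask, 0 = keep

with paddle.amp.auto_cast():  # fp16 auto-cast, level O1 by default
    # X may be cast to float16 before the fused kernel runs, but the float32
    # Mask is now passed through unchanged instead of being cast down with it.
    out = paddle.incubate.softmax_mask_fuse(x, mask)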
@@ -22,9 +22,9 @@ namespace phi {
namespace fusion {
// T == fp16
template <typename T, int pow2_index>
template <typename T, typename MT, int pow2_index>
__global__ void SoftmaxMaskFuseGPUKernel(const T* x_data,
const T* mask_data,
const MT* mask_data,
T* y_data,
int batch_count,
int key_seq_len) {
@@ -62,7 +62,7 @@ __global__ void SoftmaxMaskFuseGPUKernel(const T* x_data,
// using float for all inter compute
float data[kLocalBatchSize][kLocalIterations];
T temp_data[kOneLoadingCounts];
T temp_mask[kOneLoadingCounts];
MT temp_mask[kOneLoadingCounts];
#pragma unroll
for (int i = 0; i < kLocalBatchSize; ++i) {
@@ -151,7 +151,6 @@ void FusedSoftmaxMaskKernel(const Context& dev_ctx,
const DenseTensor& mask,
DenseTensor* out) {
auto* x_data = x.data<T>();
auto* mask_data = mask.data<T>();
auto* y_data = dev_ctx.template Alloc<T>(out);
auto x_dim = x.dims();
@@ -226,47 +225,91 @@ void FusedSoftmaxMaskKernel(const Context& dev_ctx,
dim3 blocks(query_seq_len / batches_per_block, attn_heads, batches);
dim3 threads(warp_size, warps_per_block, 1);
// launch the kernel based on the pow2_index
if (mask.dtype() == x.dtype()) {
auto* mask_data = mask.data<T>();
switch (pow2_index) {
case 5: // 32
SoftmaxMaskFuseGPUKernel<T, T, 5><<<blocks, threads, 0, stream>>>(
x_data, mask_data, y_data, batch_count, key_seq_len);
break;
case 6: // 64
SoftmaxMaskFuseGPUKernel<T, T, 6><<<blocks, threads, 0, stream>>>(
x_data, mask_data, y_data, batch_count, key_seq_len);
break;
case 7: // 128
SoftmaxMaskFuseGPUKernel<T, T, 7><<<blocks, threads, 0, stream>>>(
x_data, mask_data, y_data, batch_count, key_seq_len);
break;
case 8: // 256
SoftmaxMaskFuseGPUKernel<T, T, 8><<<blocks, threads, 0, stream>>>(
x_data, mask_data, y_data, batch_count, key_seq_len);
break;
case 9: // 512
SoftmaxMaskFuseGPUKernel<T, T, 9><<<blocks, threads, 0, stream>>>(
x_data, mask_data, y_data, batch_count, key_seq_len);
break;
case 10: // 1024
SoftmaxMaskFuseGPUKernel<T, T, 10><<<blocks, threads, 0, stream>>>(
x_data, mask_data, y_data, batch_count, key_seq_len);
break;
case 11: // 2048
SoftmaxMaskFuseGPUKernel<T, T, 11><<<blocks, threads, 0, stream>>>(
x_data, mask_data, y_data, batch_count, key_seq_len);
break;
case 12: // 4096
SoftmaxMaskFuseGPUKernel<T, T, 12><<<blocks, threads, 0, stream>>>(
x_data, mask_data, y_data, batch_count, key_seq_len);
break;
case 13: // 8192
SoftmaxMaskFuseGPUKernel<T, T, 13><<<blocks, threads, 0, stream>>>(
x_data, mask_data, y_data, batch_count, key_seq_len);
break;
default:
break;
}
} else if (mask.dtype() == phi::DataType::FLOAT32) {
auto* mask_data = mask.data<float>();
switch (pow2_index) {
case 5: // 32
SoftmaxMaskFuseGPUKernel<T, 5><<<blocks, threads, 0, stream>>>(
SoftmaxMaskFuseGPUKernel<T, float, 5><<<blocks, threads, 0, stream>>>(
x_data, mask_data, y_data, batch_count, key_seq_len);
break;
case 6: // 64
SoftmaxMaskFuseGPUKernel<T, 6><<<blocks, threads, 0, stream>>>(
SoftmaxMaskFuseGPUKernel<T, float, 6><<<blocks, threads, 0, stream>>>(
x_data, mask_data, y_data, batch_count, key_seq_len);
break;
case 7: // 128
SoftmaxMaskFuseGPUKernel<T, 7><<<blocks, threads, 0, stream>>>(
SoftmaxMaskFuseGPUKernel<T, float, 7><<<blocks, threads, 0, stream>>>(
x_data, mask_data, y_data, batch_count, key_seq_len);
break;
case 8: // 256
SoftmaxMaskFuseGPUKernel<T, 8><<<blocks, threads, 0, stream>>>(
SoftmaxMaskFuseGPUKernel<T, float, 8><<<blocks, threads, 0, stream>>>(
x_data, mask_data, y_data, batch_count, key_seq_len);
break;
case 9: // 512
SoftmaxMaskFuseGPUKernel<T, 9><<<blocks, threads, 0, stream>>>(
SoftmaxMaskFuseGPUKernel<T, float, 9><<<blocks, threads, 0, stream>>>(
x_data, mask_data, y_data, batch_count, key_seq_len);
break;
case 10: // 1024
SoftmaxMaskFuseGPUKernel<T, 10><<<blocks, threads, 0, stream>>>(
SoftmaxMaskFuseGPUKernel<T, float, 10><<<blocks, threads, 0, stream>>>(
x_data, mask_data, y_data, batch_count, key_seq_len);
break;
case 11: // 2048
SoftmaxMaskFuseGPUKernel<T, 11><<<blocks, threads, 0, stream>>>(
SoftmaxMaskFuseGPUKernel<T, float, 11><<<blocks, threads, 0, stream>>>(
x_data, mask_data, y_data, batch_count, key_seq_len);
break;
case 12: // 4096
SoftmaxMaskFuseGPUKernel<T, 12><<<blocks, threads, 0, stream>>>(
SoftmaxMaskFuseGPUKernel<T, float, 12><<<blocks, threads, 0, stream>>>(
x_data, mask_data, y_data, batch_count, key_seq_len);
break;
case 13: // 8192
SoftmaxMaskFuseGPUKernel<T, 13><<<blocks, threads, 0, stream>>>(
SoftmaxMaskFuseGPUKernel<T, float, 13><<<blocks, threads, 0, stream>>>(
x_data, mask_data, y_data, batch_count, key_seq_len);
break;
default:
break;
}
}
}
} // namespace fusion
......
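On the host side, the launch code above now dispatches on the mask's runtime dtype (either the same dtype as X, or float32), while the device code keeps accumulating in float as before; MT only changes how mask values are loaded. Below is a NumPy reference of the fused computation, handy for checking either branch. The _get_softmax helper used by the tests is presumably equivalent, but its definition is not shown in this diff, so this is an assumed reference rather than the project's own:

import numpy as np

def softmax_mask_fuse_ref(x, mask):
    # softmax(x + mask) along the last axis, accumulated in float32
    # regardless of the input dtypes, matching the kernel's behaviour.
    s = x.astype(np.float32) + mask.astype(np.float32)
    s = s - s.max(axis=-1, keepdims=True)  # stabilise the exponent
    e = np.exp(s)
    out = e / e.sum(axis=-1, keepdims=True)
    return out.astype(x.dtype)  # result follows X's dtype (T)

x = np.random.random((1, 1, 8, 32)).astype("float16")
pad = np.random.randint(0, 2, (1, 1, 8, 32))
mask = np.where(pad == 1, -10000.0, 0.0).astype("float32")
ref = softmax_mask_fuse_ref(x, mask)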
@@ -78,6 +78,27 @@ class TestSoftmaxMaskFuseOp0(OpTest):
        self.check_grad_with_place(core.CUDAPlace(0), ["X"], "Out")


@unittest.skipIf(
    not core.is_compiled_with_cuda(), "core is not compiled with CUDA"
)
class TestSoftmaxMaskFuseOp01(OpTest):
    def setUp(self):
        self.op_type = "fused_softmax_mask"
        self.python_api = paddle.incubate.softmax_mask_fuse
        x = np.random.random((1, 1, 8, 32)).astype("float16")
        mask = np.random.randint(0, 2, (1, 1, 8, 32)).astype("float32")
        mask_input = np.where(mask == 1, -10000.0, mask)
        self.inputs = {'X': x, 'Mask': mask_input}
        rst = _get_softmax(x, mask_input)
        self.outputs = {'Out': rst}

    def test_check_output(self):
        self.check_output_with_place(core.CUDAPlace(0))

    def test_check_grad(self):
        self.check_grad_with_place(core.CUDAPlace(0), ["X"], "Out")


@unittest.skipIf(
    not core.is_compiled_with_cuda(), "core is not compiled with CUDA"
)
......