Unverified commit 25a0b46d, authored by duanyanhui, committed by GitHub

optimize softmax_mask_fuse (#56877)

Parent d38cd6ce
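
This commit extends fused_softmax_mask so that a float32 Mask can be used together with a float16 X: the AMP autocast path below keeps such a Mask in float32, and the CUDA kernel now dispatches on the mask dtype. A minimal usage sketch based on the new unit test in this diff (the shapes and the -10000.0 masking value come from that test; a CUDA device is required):

```python
import numpy as np
import paddle

# float16 input with a float32 additive mask, mirroring TestSoftmaxMaskFuseOp01 below
x = paddle.to_tensor(np.random.random((1, 1, 8, 32)).astype("float16"))
mask = np.random.randint(0, 2, (1, 1, 8, 32)).astype("float32")
mask = paddle.to_tensor(np.where(mask == 1, -10000.0, mask).astype("float32"))

# The fused op computes softmax(x + mask) along the last (key) dimension.
out = paddle.incubate.softmax_mask_fuse(x, mask)
```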
...@@ -75,6 +75,10 @@ inline paddle::Tensor AmpAutoCast(const std::string& input_name,
       input_name != "X") {
     return input;
   }
+  if (op_name == "fused_softmax_mask" && input_name == "Mask" &&
+      input.dtype() == phi::DataType::FLOAT32) {
+    return input;
+  }
   if (dst_dtype == phi::DataType::FLOAT16) {
     if (op_name == "run_program") {
       return input;
......
...@@ -22,9 +22,9 @@ namespace phi {
 namespace fusion {

 // T == fp16
-template <typename T, int pow2_index>
+template <typename T, typename MT, int pow2_index>
 __global__ void SoftmaxMaskFuseGPUKernel(const T* x_data,
-                                         const T* mask_data,
+                                         const MT* mask_data,
                                          T* y_data,
                                          int batch_count,
                                          int key_seq_len) {
...@@ -62,7 +62,7 @@ __global__ void SoftmaxMaskFuseGPUKernel(const T* x_data,
   // using float for all inter compute
   float data[kLocalBatchSize][kLocalIterations];
   T temp_data[kOneLoadingCounts];
-  T temp_mask[kOneLoadingCounts];
+  MT temp_mask[kOneLoadingCounts];

 #pragma unroll
   for (int i = 0; i < kLocalBatchSize; ++i) {
...@@ -151,7 +151,6 @@ void FusedSoftmaxMaskKernel(const Context& dev_ctx,
                             const DenseTensor& mask,
                             DenseTensor* out) {
   auto* x_data = x.data<T>();
-  auto* mask_data = mask.data<T>();
   auto* y_data = dev_ctx.template Alloc<T>(out);

   auto x_dim = x.dims();
...@@ -226,46 +225,90 @@ void FusedSoftmaxMaskKernel(const Context& dev_ctx,
   dim3 blocks(query_seq_len / batches_per_block, attn_heads, batches);
   dim3 threads(warp_size, warps_per_block, 1);
-  // launch the kernel based on the pow2_index
-  switch (pow2_index) {
-    case 5:  // 32
-      SoftmaxMaskFuseGPUKernel<T, 5><<<blocks, threads, 0, stream>>>(
-          x_data, mask_data, y_data, batch_count, key_seq_len);
-      break;
-    case 6:  // 64
-      SoftmaxMaskFuseGPUKernel<T, 6><<<blocks, threads, 0, stream>>>(
-          x_data, mask_data, y_data, batch_count, key_seq_len);
-      break;
-    case 7:  // 128
-      SoftmaxMaskFuseGPUKernel<T, 7><<<blocks, threads, 0, stream>>>(
-          x_data, mask_data, y_data, batch_count, key_seq_len);
-      break;
-    case 8:  // 256
-      SoftmaxMaskFuseGPUKernel<T, 8><<<blocks, threads, 0, stream>>>(
-          x_data, mask_data, y_data, batch_count, key_seq_len);
-      break;
-    case 9:  // 512
-      SoftmaxMaskFuseGPUKernel<T, 9><<<blocks, threads, 0, stream>>>(
-          x_data, mask_data, y_data, batch_count, key_seq_len);
-      break;
-    case 10:  // 1024
-      SoftmaxMaskFuseGPUKernel<T, 10><<<blocks, threads, 0, stream>>>(
-          x_data, mask_data, y_data, batch_count, key_seq_len);
-      break;
-    case 11:  // 2048
-      SoftmaxMaskFuseGPUKernel<T, 11><<<blocks, threads, 0, stream>>>(
-          x_data, mask_data, y_data, batch_count, key_seq_len);
-      break;
-    case 12:  // 4096
-      SoftmaxMaskFuseGPUKernel<T, 12><<<blocks, threads, 0, stream>>>(
-          x_data, mask_data, y_data, batch_count, key_seq_len);
-      break;
-    case 13:  // 8192
-      SoftmaxMaskFuseGPUKernel<T, 13><<<blocks, threads, 0, stream>>>(
-          x_data, mask_data, y_data, batch_count, key_seq_len);
-      break;
-    default:
-      break;
+  if (mask.dtype() == x.dtype()) {
+    auto* mask_data = mask.data<T>();
+    switch (pow2_index) {
+      case 5:  // 32
+        SoftmaxMaskFuseGPUKernel<T, T, 5><<<blocks, threads, 0, stream>>>(
+            x_data, mask_data, y_data, batch_count, key_seq_len);
+        break;
+      case 6:  // 64
+        SoftmaxMaskFuseGPUKernel<T, T, 6><<<blocks, threads, 0, stream>>>(
+            x_data, mask_data, y_data, batch_count, key_seq_len);
+        break;
+      case 7:  // 128
+        SoftmaxMaskFuseGPUKernel<T, T, 7><<<blocks, threads, 0, stream>>>(
+            x_data, mask_data, y_data, batch_count, key_seq_len);
+        break;
+      case 8:  // 256
+        SoftmaxMaskFuseGPUKernel<T, T, 8><<<blocks, threads, 0, stream>>>(
+            x_data, mask_data, y_data, batch_count, key_seq_len);
+        break;
+      case 9:  // 512
+        SoftmaxMaskFuseGPUKernel<T, T, 9><<<blocks, threads, 0, stream>>>(
+            x_data, mask_data, y_data, batch_count, key_seq_len);
+        break;
+      case 10:  // 1024
+        SoftmaxMaskFuseGPUKernel<T, T, 10><<<blocks, threads, 0, stream>>>(
+            x_data, mask_data, y_data, batch_count, key_seq_len);
+        break;
+      case 11:  // 2048
+        SoftmaxMaskFuseGPUKernel<T, T, 11><<<blocks, threads, 0, stream>>>(
+            x_data, mask_data, y_data, batch_count, key_seq_len);
+        break;
+      case 12:  // 4096
+        SoftmaxMaskFuseGPUKernel<T, T, 12><<<blocks, threads, 0, stream>>>(
+            x_data, mask_data, y_data, batch_count, key_seq_len);
+        break;
+      case 13:  // 8192
+        SoftmaxMaskFuseGPUKernel<T, T, 13><<<blocks, threads, 0, stream>>>(
+            x_data, mask_data, y_data, batch_count, key_seq_len);
+        break;
+      default:
+        break;
+    }
+  } else if (mask.dtype() == phi::DataType::FLOAT32) {
+    auto* mask_data = mask.data<float>();
+    switch (pow2_index) {
+      case 5:  // 32
+        SoftmaxMaskFuseGPUKernel<T, float, 5><<<blocks, threads, 0, stream>>>(
+            x_data, mask_data, y_data, batch_count, key_seq_len);
+        break;
+      case 6:  // 64
+        SoftmaxMaskFuseGPUKernel<T, float, 6><<<blocks, threads, 0, stream>>>(
+            x_data, mask_data, y_data, batch_count, key_seq_len);
+        break;
+      case 7:  // 128
+        SoftmaxMaskFuseGPUKernel<T, float, 7><<<blocks, threads, 0, stream>>>(
+            x_data, mask_data, y_data, batch_count, key_seq_len);
+        break;
+      case 8:  // 256
+        SoftmaxMaskFuseGPUKernel<T, float, 8><<<blocks, threads, 0, stream>>>(
+            x_data, mask_data, y_data, batch_count, key_seq_len);
+        break;
+      case 9:  // 512
+        SoftmaxMaskFuseGPUKernel<T, float, 9><<<blocks, threads, 0, stream>>>(
+            x_data, mask_data, y_data, batch_count, key_seq_len);
+        break;
+      case 10:  // 1024
+        SoftmaxMaskFuseGPUKernel<T, float, 10><<<blocks, threads, 0, stream>>>(
+            x_data, mask_data, y_data, batch_count, key_seq_len);
+        break;
+      case 11:  // 2048
+        SoftmaxMaskFuseGPUKernel<T, float, 11><<<blocks, threads, 0, stream>>>(
+            x_data, mask_data, y_data, batch_count, key_seq_len);
+        break;
+      case 12:  // 4096
+        SoftmaxMaskFuseGPUKernel<T, float, 12><<<blocks, threads, 0, stream>>>(
+            x_data, mask_data, y_data, batch_count, key_seq_len);
+        break;
+      case 13:  // 8192
+        SoftmaxMaskFuseGPUKernel<T, float, 13><<<blocks, threads, 0, stream>>>(
+            x_data, mask_data, y_data, batch_count, key_seq_len);
+        break;
+      default:
+        break;
+    }
   }
 }
......
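
For reference, the kernel above computes a softmax over x + mask along the last (key) dimension and, as the `float data[...]` buffer shows, carries the intermediate math in float32 regardless of the input dtype. A NumPy sketch of that reference computation, written to mirror what the tests' `_get_softmax` helper checks against (the helper's exact implementation is an assumption):

```python
import numpy as np

def softmax_mask_fuse_ref(x, mask):
    # Reference for the fused op: softmax(x + mask) over the last axis,
    # with intermediates accumulated in float32 as in the CUDA kernel.
    z = x.astype(np.float32) + mask.astype(np.float32)
    z = z - z.max(axis=-1, keepdims=True)  # subtract the row max for stability
    e = np.exp(z)
    return (e / e.sum(axis=-1, keepdims=True)).astype(x.dtype)
```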
...@@ -78,6 +78,27 @@ class TestSoftmaxMaskFuseOp0(OpTest):
         self.check_grad_with_place(core.CUDAPlace(0), ["X"], "Out")


+@unittest.skipIf(
+    not core.is_compiled_with_cuda(), "core is not compiled with CUDA"
+)
+class TestSoftmaxMaskFuseOp01(OpTest):
+    def setUp(self):
+        self.op_type = "fused_softmax_mask"
+        self.python_api = paddle.incubate.softmax_mask_fuse
+        x = np.random.random((1, 1, 8, 32)).astype("float16")
+        mask = np.random.randint(0, 2, (1, 1, 8, 32)).astype("float32")
+        mask_input = np.where(mask == 1, -10000.0, mask)
+        self.inputs = {'X': x, 'Mask': mask_input}
+        rst = _get_softmax(x, mask_input)
+        self.outputs = {'Out': rst}
+
+    def test_check_output(self):
+        self.check_output_with_place(core.CUDAPlace(0))
+
+    def test_check_grad(self):
+        self.check_grad_with_place(core.CUDAPlace(0), ["X"], "Out")
+
+
 @unittest.skipIf(
     not core.is_compiled_with_cuda(), "core is not compiled with CUDA"
 )
......
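
With the AmpAutoCast exception at the top of this diff, a float32 Mask fed to fused_softmax_mask under AMP is left in float32 while X may be cast down to float16, and the kernel dispatch above then selects the <T, float, ...> instantiation. A hedged sketch of that path (placing the op on the AMP white list via custom_white_list is an illustrative assumption, and a CUDA device is required):

```python
import numpy as np
import paddle

x = paddle.to_tensor(np.random.random((1, 1, 8, 32)).astype("float32"))
mask = paddle.to_tensor(
    (np.random.randint(0, 2, (1, 1, 8, 32)) * -10000.0).astype("float32")
)

# Under AMP, X may be auto-cast to float16 while the float32 Mask is kept as-is.
with paddle.amp.auto_cast(custom_white_list={"fused_softmax_mask"}):
    out = paddle.incubate.softmax_mask_fuse(x, mask)
```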