From b2d550ab850948458abacb167577603bd7b3ab5f Mon Sep 17 00:00:00 2001
From: Michael Wyatt
Date: Fri, 9 Sep 2022 09:31:39 -0700
Subject: [PATCH] Unit test for bias add kernel (#2298)

* added unit test

* Update pt_binding.cpp

* formatting

* Update test_bias_add.py
---
 csrc/transformer/inference/csrc/gelu.cu       |  1 +
 .../transformer/inference/csrc/pt_binding.cpp | 26 ++++++++--
 .../transformer/inference/test_bias_add.py    | 48 +++++++++++++++++++
 3 files changed, 71 insertions(+), 4 deletions(-)
 create mode 100644 tests/unit/ops/transformer/inference/test_bias_add.py

diff --git a/csrc/transformer/inference/csrc/gelu.cu b/csrc/transformer/inference/csrc/gelu.cu
index f6ac9386..ee952f50 100644
--- a/csrc/transformer/inference/csrc/gelu.cu
+++ b/csrc/transformer/inference/csrc/gelu.cu
@@ -92,6 +92,7 @@ void launch_bias_gelu(T* input,
 template void launch_bias_gelu<float>(float*, const float*, int, int, cudaStream_t);
 template void launch_bias_gelu<__half>(__half*, const __half*, int, int, cudaStream_t);
 
+// Not called directly from DeepSpeed, but used in ds_qkv_gemm_int8, ds_linear_layer, etc.
 __global__ void fused_bias_add(float* input, const float* bias, int total_count, int hidden_size)
 {
     float4* input_cast = reinterpret_cast<float4*>(input);
diff --git a/csrc/transformer/inference/csrc/pt_binding.cpp b/csrc/transformer/inference/csrc/pt_binding.cpp
index fee42521..592812c7 100644
--- a/csrc/transformer/inference/csrc/pt_binding.cpp
+++ b/csrc/transformer/inference/csrc/pt_binding.cpp
@@ -546,6 +546,22 @@ at::Tensor ds_bias_relu(at::Tensor& input, at::Tensor& bias)
     return input_cont;
 }
 
+template <typename T>
+at::Tensor ds_bias_add(at::Tensor& input, at::Tensor& bias)
+{
+    auto input_cont = input.contiguous();
+
+    int bsz = input_cont.size(0) * input_cont.size(1);
+    int hidden_size = input_cont.size(2);
+
+    launch_bias_add((T*)input_cont.data_ptr(),
+                    (T*)bias.data_ptr(),
+                    hidden_size,
+                    bsz,
+                    Context::Instance().GetCurrentStream());
+    return input_cont;
+}
+
 template <typename T>
 at::Tensor ds_bias_residual(at::Tensor& input, at::Tensor& residual, at::Tensor& bias)
 {
@@ -1323,17 +1339,19 @@ at::Tensor moe_res_matmul(at::Tensor& moe_res, at::Tensor& coef, at::Tensor& out
 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
 {
     m.def("softmax_fp32", &ds_softmax<float>, "DeepSpeed SoftMax with fp32 (CUDA)");
-    m.def("softmax_fp16", &ds_softmax<__half>, "DeepSpeed SoftMax with fp32 (CUDA)");
+    m.def("softmax_fp16", &ds_softmax<__half>, "DeepSpeed SoftMax with fp16 (CUDA)");
     m.def(
         "softmax_context_fp32", &ds_softmax_context<float>, "DeepSpeed attention with fp32 (CUDA)");
     m.def("softmax_context_fp16",
           &ds_softmax_context<__half>,
-          "DeepSpeed attention with fp32 (CUDA)");
+          "DeepSpeed attention with fp16 (CUDA)");
     m.def("softmax_context_int8",
           &ds_softmax_context1<__half>,
-          "DeepSpeed attention with fp32 (CUDA)");
+          "DeepSpeed attention with int8 (CUDA)");
     m.def("bias_gelu_fp32", &ds_bias_gelu<float>, "DeepSpeed Gelu with fp32 (CUDA)");
     m.def("bias_gelu_fp16", &ds_bias_gelu<__half>, "DeepSpeed Gelu with fp16 (CUDA)");
+    m.def("bias_add_fp32", &ds_bias_add<float>, "DeepSpeed Bias Add with fp32 (CUDA)");
+    m.def("bias_add_fp16", &ds_bias_add<__half>, "DeepSpeed Bias Add with fp16 (CUDA)");
     m.def("bias_relu_fp32", &ds_bias_relu<float>, "DeepSpeed ReLU with fp32 (CUDA)");
     m.def("bias_relu_fp16", &ds_bias_relu<__half>, "DeepSpeed ReLU with fp16 (CUDA)");
     m.def("bias_residual_fp32",
           &ds_bias_residual<float>,
           "DeepSpeed residual-bias add with fp32 (CUDA)");
     m.def("bias_residual_fp16",
           &ds_bias_residual<__half>,
-          "DeepSpeed residual-bias add with fp32 (CUDA)");
(CUDA)"); + "DeepSpeed residual-bias add with fp16 (CUDA)"); m.def("layer_norm_fp32", &ds_layernorm, "DeepSpeed layer-norm with fp32 (CUDA)"); m.def("layer_norm_fp16", &ds_layernorm<__half>, "DeepSpeed layer-norm with fp16 (CUDA)"); m.def("qkv_gemm_fp32", &ds_qkv_gemm, "DeepSpeed qkv gemm with fp32 (CUDA)"); diff --git a/tests/unit/ops/transformer/inference/test_bias_add.py b/tests/unit/ops/transformer/inference/test_bias_add.py new file mode 100644 index 00000000..2077390a --- /dev/null +++ b/tests/unit/ops/transformer/inference/test_bias_add.py @@ -0,0 +1,48 @@ +import pytest +import torch +import deepspeed +from deepspeed.ops.op_builder import InferenceBuilder + +if not deepspeed.ops.__compatible_ops__[InferenceBuilder.NAME]: + pytest.skip("Inference ops are not available on this system", + allow_module_level=True) + +inference_module = None +torch_minor_version = None + + +def allclose(x, y): + assert x.dtype == y.dtype + rtol, atol = {torch.float32: (5e-4, 5e-5), torch.float16: (3e-2, 2e-3)}[x.dtype] + return torch.allclose(x, y, rtol=rtol, atol=atol) + + +def run_bias_add_reference(activations, bias): + return activations + bias + + +def run_bias_add_ds(activations, bias): + global inference_module + if inference_module is None: + inference_module = InferenceBuilder().load() + if activations.dtype == torch.float16: + return inference_module.bias_add_fp16(activations, bias) + else: + return inference_module.bias_add_fp32(activations, bias) + + +@pytest.mark.inference +@pytest.mark.parametrize("batch", [1, 2]) +@pytest.mark.parametrize("sequence", [1, 128, 255]) +@pytest.mark.parametrize("channels", [512, 1232, 4096]) +@pytest.mark.parametrize("dtype", [torch.float16, torch.float32], ids=["fp16", "fp32"]) +def test_bias_add(batch, sequence, channels, dtype): + activations_ds = torch.randn((batch, sequence, channels), dtype=dtype, device='cuda') + bias_ds = torch.randn((channels), dtype=dtype, device='cuda') + + activations_ref = activations_ds.clone().detach() + bias_ref = bias_ds.clone().detach() + + ds_out = run_bias_add_ds(activations_ds, bias_ds) + ref_out = run_bias_add_reference(activations_ref, bias_ref) + assert allclose(ds_out, ref_out) -- GitLab