Unverified commit b2d550ab, authored by Michael Wyatt and committed by GitHub

Unit test for bias add kernel (#2298)

* added unit test

* Update pt_binding.cpp

* formatting

* Update test_bias_add.py
Parent: 4524b8db
@@ -92,6 +92,7 @@ void launch_bias_gelu(T* input,
template void launch_bias_gelu<float>(float*, const float*, int, int, cudaStream_t);
template void launch_bias_gelu<__half>(__half*, const __half*, int, int, cudaStream_t);
// Not called directly from DeepSpeed, but used in ds_qkv_gemm_int8, ds_linear_layer, etc.
__global__ void fused_bias_add(float* input, const float* bias, int total_count, int hidden_size)
{
float4* input_cast = reinterpret_cast<float4*>(input);
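For orientation: fused_bias_add walks the activations through float4 casts, so each thread consumes four consecutive fp32 values at a time, but the arithmetic itself is just a broadcast add of the bias over the hidden dimension. A minimal PyTorch sketch of those semantics (fused_bias_add_semantics is a hypothetical helper name, not part of this patch):

import torch

def fused_bias_add_semantics(activations: torch.Tensor, bias: torch.Tensor) -> torch.Tensor:
    # Every hidden_size-long row of the activations receives the same
    # bias vector; broadcasting expresses the kernel's per-element add.
    hidden_size = bias.numel()
    flat = activations.reshape(-1, hidden_size)
    return (flat + bias).reshape(activations.shape)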
@@ -546,6 +546,22 @@ at::Tensor ds_bias_relu(at::Tensor& input, at::Tensor& bias)
return input_cont;
}
template <typename T>
at::Tensor ds_bias_add(at::Tensor& input, at::Tensor& bias)
{
auto input_cont = input.contiguous();
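    // bsz folds the batch and sequence dimensions together; the kernel
    // effectively sees a (bsz, hidden_size) view of the activations.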
int bsz = input_cont.size(0) * input_cont.size(1);
int hidden_size = input_cont.size(2);
launch_bias_add((T*)input_cont.data_ptr(),
(T*)bias.data_ptr(),
hidden_size,
bsz,
Context::Instance().GetCurrentStream());
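    // launch_bias_add appears to update input_cont in place (the unit test
    // below clones its inputs before calling), so the tensor is returned as-is.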
return input_cont;
}
template <typename T>
at::Tensor ds_bias_residual(at::Tensor& input, at::Tensor& residual, at::Tensor& bias)
{
@@ -1323,17 +1339,19 @@ at::Tensor moe_res_matmul(at::Tensor& moe_res, at::Tensor& coef, at::Tensor& out
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
m.def("softmax_fp32", &ds_softmax<float>, "DeepSpeed SoftMax with fp32 (CUDA)");
m.def("softmax_fp16", &ds_softmax<__half>, "DeepSpeed SoftMax with fp32 (CUDA)");
m.def("softmax_fp16", &ds_softmax<__half>, "DeepSpeed SoftMax with fp16 (CUDA)");
m.def(
"softmax_context_fp32", &ds_softmax_context<float>, "DeepSpeed attention with fp32 (CUDA)");
m.def("softmax_context_fp16",
&ds_softmax_context<__half>,
"DeepSpeed attention with fp32 (CUDA)");
"DeepSpeed attention with fp16 (CUDA)");
m.def("softmax_context_int8",
&ds_softmax_context1<__half>,
"DeepSpeed attention with fp32 (CUDA)");
"DeepSpeed attention with int8 (CUDA)");
m.def("bias_gelu_fp32", &ds_bias_gelu<float>, "DeepSpeed Gelu with fp32 (CUDA)");
m.def("bias_gelu_fp16", &ds_bias_gelu<__half>, "DeepSpeed Gelu with fp16 (CUDA)");
m.def("bias_add_fp32", &ds_bias_add<float>, "DeepSpeed Bias Add with fp32 (CUDA)");
m.def("bias_add_fp16", &ds_bias_add<__half>, "DeepSpeed Gelu with fp16 (CUDA)");
m.def("bias_relu_fp32", &ds_bias_relu<float>, "DeepSpeed ReLU with fp32 (CUDA)");
m.def("bias_relu_fp16", &ds_bias_relu<__half>, "DeepSpeed ReLU with fp16 (CUDA)");
m.def("bias_residual_fp32",
@@ -1341,7 +1359,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
"DeepSpeed residual-bias add with fp32 (CUDA)");
m.def("bias_residual_fp16",
&ds_bias_residual<__half>,
"DeepSpeed residual-bias add with fp32 (CUDA)");
"DeepSpeed residual-bias add with fp16 (CUDA)");
m.def("layer_norm_fp32", &ds_layernorm<float>, "DeepSpeed layer-norm with fp32 (CUDA)");
m.def("layer_norm_fp16", &ds_layernorm<__half>, "DeepSpeed layer-norm with fp16 (CUDA)");
m.def("qkv_gemm_fp32", &ds_qkv_gemm<float>, "DeepSpeed qkv gemm with fp32 (CUDA)");
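On the Python side these bindings are reached through the JIT-built inference extension, which is exactly what the new test_bias_add.py below does. A minimal direct call might look like the following sketch (assuming a CUDA-capable build; the op appears to mutate its input in place, hence the clone):

import torch
from deepspeed.ops.op_builder import InferenceBuilder

module = InferenceBuilder().load()   # builds/loads the extension on first use
x = torch.randn(1, 8, 512, dtype=torch.float32, device='cuda')
b = torch.randn(512, dtype=torch.float32, device='cuda')
x_ref = x.clone()                    # keep a pristine copy for the reference
y = module.bias_add_fp32(x, b)
assert torch.allclose(y, x_ref + b, rtol=5e-4, atol=5e-5)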
import pytest
import torch
import deepspeed
from deepspeed.ops.op_builder import InferenceBuilder

if not deepspeed.ops.__compatible_ops__[InferenceBuilder.NAME]:
    pytest.skip("Inference ops are not available on this system",
                allow_module_level=True)

inference_module = None
torch_minor_version = None


def allclose(x, y):
    assert x.dtype == y.dtype
    # Per-dtype tolerances: fp16 needs looser bounds than fp32.
    rtol, atol = {torch.float32: (5e-4, 5e-5), torch.float16: (3e-2, 2e-3)}[x.dtype]
    return torch.allclose(x, y, rtol=rtol, atol=atol)


def run_bias_add_reference(activations, bias):
    # A plain broadcast add in PyTorch is the ground truth.
    return activations + bias


def run_bias_add_ds(activations, bias):
    # Build/load the DeepSpeed inference extension once and cache it.
    global inference_module
    if inference_module is None:
        inference_module = InferenceBuilder().load()
    if activations.dtype == torch.float16:
        return inference_module.bias_add_fp16(activations, bias)
    else:
        return inference_module.bias_add_fp32(activations, bias)


@pytest.mark.inference
@pytest.mark.parametrize("batch", [1, 2])
@pytest.mark.parametrize("sequence", [1, 128, 255])
@pytest.mark.parametrize("channels", [512, 1232, 4096])
@pytest.mark.parametrize("dtype", [torch.float16, torch.float32], ids=["fp16", "fp32"])
def test_bias_add(batch, sequence, channels, dtype):
    activations_ds = torch.randn((batch, sequence, channels), dtype=dtype, device='cuda')
    bias_ds = torch.randn((channels, ), dtype=dtype, device='cuda')
    # The kernel adds the bias in place, so keep detached copies for the reference.
    activations_ref = activations_ds.clone().detach()
    bias_ref = bias_ds.clone().detach()

    ds_out = run_bias_add_ds(activations_ds, bias_ds)
    ref_out = run_bias_add_reference(activations_ref, bias_ref)
    assert allclose(ds_out, ref_out)
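Since the test carries the inference marker, it can be selected on its own with, for example, pytest -m inference test_bias_add.py.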