diff --git a/csrc/transformer/inference/csrc/gelu.cu b/csrc/transformer/inference/csrc/gelu.cu
index 928f64c290e95b0806899a55cd339ac91615b5a5..47707c7c64a5e735d4232363038538370174331b 100644
--- a/csrc/transformer/inference/csrc/gelu.cu
+++ b/csrc/transformer/inference/csrc/gelu.cu
@@ -471,7 +471,6 @@ __global__ void moe_res_matmul(__half* residual,
 {
 #ifdef HALF_PRECISION_AVAILABLE
     unsigned tid = threadIdx.x;
-
     float2* residual_cast = reinterpret_cast<float2*>(residual);
     float2* mlp_out_cast = reinterpret_cast<float2*>(mlp_out);
     float2* coef_cast = reinterpret_cast<float2*>(coef);
@@ -483,7 +482,7 @@
     while (tid < hidden_dim) {
         float2 res = residual_cast[tid];
         float2 coef1 = coef_cast[tid];
-        float2 coef2 = coef_cast[tid];
+        float2 coef2 = coef_cast2[tid];
         float2 data = mlp_out_cast[tid];
         __half* data_h = reinterpret_cast<__half*>(&data);
         __half* coef1_h = reinterpret_cast<__half*>(&coef1);
diff --git a/tests/unit/ops/transformer/inference/test_moe_res_matmult.py b/tests/unit/ops/transformer/inference/test_moe_res_matmult.py
new file mode 100644
index 0000000000000000000000000000000000000000..8b1b1cb161682019b78e2312df55ad5f5c31ec5f
--- /dev/null
+++ b/tests/unit/ops/transformer/inference/test_moe_res_matmult.py
@@ -0,0 +1,56 @@
+"""
+Copyright 2022 The Microsoft DeepSpeed Team
+"""
+
+import pytest
+import torch
+import deepspeed
+from deepspeed.ops.op_builder import InferenceBuilder
+
+if not deepspeed.ops.__compatible_ops__[InferenceBuilder.NAME]:
+    pytest.skip("Inference ops are not available on this system",
+                allow_module_level=True)
+
+inference_module = None
+
+
+def allclose(x, y):
+    assert x.dtype == y.dtype
+    rtol, atol = {torch.float32: (5e-4, 5e-5), torch.float16: (3e-2, 2e-3)}[x.dtype]
+    return torch.allclose(x, y, rtol=rtol, atol=atol)
+
+
+def run_moe_res_matmul_reference(residual, coef1, coef2, output):
+    return residual * coef1 + output * coef2
+
+
+def run_moe_res_matmul_ds(residual, coef, output):
+    global inference_module
+    if inference_module is None:
+        inference_module = InferenceBuilder().load()
+    coef_t = coef.transpose(-1, -2).contiguous()
+    return inference_module.moe_res_matmul(residual, coef_t, output)
+
+
+@pytest.mark.inference
+@pytest.mark.parametrize("hidden_dim", [16, 64])
+@pytest.mark.parametrize("c", [1, 4])
+@pytest.mark.parametrize("dtype", [torch.float32, torch.float16])
+def test_moe_residual_matmul(hidden_dim, c, dtype):
+    residual_ds = torch.randn((c,
+                               hidden_dim * c,
+                               hidden_dim),
+                              dtype=dtype,
+                              device='cuda')
+    coeff1 = torch.randn((1, 1, hidden_dim), dtype=dtype, device='cuda')
+    coeff2 = torch.randn((1, 1, hidden_dim), dtype=dtype, device='cuda')
+    out_ds = torch.randn((c, hidden_dim * c, hidden_dim), dtype=dtype, device='cuda')
+    coeff_ds = torch.cat((coeff1, coeff2), dim=-1)
+    residual_ref = residual_ds.clone().detach()
+    coeff_ref = coeff_ds.clone().detach()
+    out_ref = out_ds.clone().detach()
+
+    ds_out = run_moe_res_matmul_ds(residual_ds, coeff_ds, out_ds)
+    ref_out = run_moe_res_matmul_reference(residual_ref, coeff1, coeff2, out_ref)
+
+    assert (allclose(ds_out, ref_out))