From ea3cf7f82105dbb8e4a951fbc8083a9773b99e93 Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Fri, 26 May 2023 14:14:15 +0800 Subject: [PATCH] test(complex): add simple rope tests for complex dtype GitOrigin-RevId: cf60f1659d68c61f1f25db3247811f9c4b0c1f47 --- .../test/unit/functional/test_tensor.py | 22 +++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/imperative/python/test/unit/functional/test_tensor.py b/imperative/python/test/unit/functional/test_tensor.py index ff810493c..a4a81e3f9 100644 --- a/imperative/python/test/unit/functional/test_tensor.py +++ b/imperative/python/test/unit/functional/test_tensor.py @@ -1,6 +1,7 @@ # -*- coding: utf-8 -*- import os import platform +from typing import Tuple import numpy as np import pytest @@ -8,6 +9,7 @@ from utils import get_var_value, make_tensor, opr_test import megengine.functional as F from megengine import Tensor +from megengine.core._imperative_rt.core2 import create_complex from megengine.core._trace_option import use_symbolic_shape from megengine.core.tensor import megbrain_graph as G from megengine.core.tensor.utils import astensor1d @@ -1076,3 +1078,23 @@ def test_roll_empty_tensor(shape, shifts, axis, is_symbolic): np.testing.assert_equal(out.numpy(), out_ref) if is_symbolic is None: break + + +def test_polar(): + def polar(abs, angle): + return F.polar(abs, angle) + + def numpy_polar(abs, angle): + return abs * np.cos(angle) + abs * np.sin(angle) * 1j + + cases = [{"input": [np.random.random((2, 3, 4)), np.random.random((2, 3, 4))]}] + + # complex can not be trace output + opr_test(cases, polar, ref_fn=numpy_polar, test_trace=False) + + +def test_create_complex(): + real = Tensor(np.arange(0, 6).reshape((1, 2, 3)).astype("float32")) + imag = Tensor(np.arange(0, 6).reshape((1, 2, 3)).astype("float32")) + complex = create_complex(real, imag) + np.testing.assert_allclose(complex.numpy(), real.numpy() + imag.numpy() * 1j) -- GitLab