diff --git a/imperative/python/test/unit/functional/test_tensor.py b/imperative/python/test/unit/functional/test_tensor.py index ff810493c6cc8fe20a236de5ccc887ab0d87170c..a4a81e3f9e986a9cb578d3d05fd3d495d1709fc5 100644 --- a/imperative/python/test/unit/functional/test_tensor.py +++ b/imperative/python/test/unit/functional/test_tensor.py @@ -1,6 +1,7 @@ # -*- coding: utf-8 -*- import os import platform +from typing import Tuple import numpy as np import pytest @@ -8,6 +9,7 @@ from utils import get_var_value, make_tensor, opr_test import megengine.functional as F from megengine import Tensor +from megengine.core._imperative_rt.core2 import create_complex from megengine.core._trace_option import use_symbolic_shape from megengine.core.tensor import megbrain_graph as G from megengine.core.tensor.utils import astensor1d @@ -1076,3 +1078,23 @@ def test_roll_empty_tensor(shape, shifts, axis, is_symbolic): np.testing.assert_equal(out.numpy(), out_ref) if is_symbolic is None: break + + +def test_polar(): + def polar(abs, angle): + return F.polar(abs, angle) + + def numpy_polar(abs, angle): + return abs * np.cos(angle) + abs * np.sin(angle) * 1j + + cases = [{"input": [np.random.random((2, 3, 4)), np.random.random((2, 3, 4))]}] + + # complex can not be trace output + opr_test(cases, polar, ref_fn=numpy_polar, test_trace=False) + + +def test_create_complex(): + real = Tensor(np.arange(0, 6).reshape((1, 2, 3)).astype("float32")) + imag = Tensor(np.arange(0, 6).reshape((1, 2, 3)).astype("float32")) + complex = create_complex(real, imag) + np.testing.assert_allclose(complex.numpy(), real.numpy() + imag.numpy() * 1j)