提交 ea3cf7f8 编写于 作者: M Megvii Engine Team

test(complex): add simple rope tests for complex dtype

GitOrigin-RevId: cf60f1659d68c61f1f25db3247811f9c4b0c1f47
上级 74b8af4d
# -*- coding: utf-8 -*-
import os
import platform
from typing import Tuple
import numpy as np
import pytest
......@@ -8,6 +9,7 @@ from utils import get_var_value, make_tensor, opr_test
import megengine.functional as F
from megengine import Tensor
from megengine.core._imperative_rt.core2 import create_complex
from megengine.core._trace_option import use_symbolic_shape
from megengine.core.tensor import megbrain_graph as G
from megengine.core.tensor.utils import astensor1d
......@@ -1076,3 +1078,23 @@ def test_roll_empty_tensor(shape, shifts, axis, is_symbolic):
np.testing.assert_equal(out.numpy(), out_ref)
if is_symbolic is None:
break
def test_polar():
def polar(abs, angle):
return F.polar(abs, angle)
def numpy_polar(abs, angle):
return abs * np.cos(angle) + abs * np.sin(angle) * 1j
cases = [{"input": [np.random.random((2, 3, 4)), np.random.random((2, 3, 4))]}]
# complex can not be trace output
opr_test(cases, polar, ref_fn=numpy_polar, test_trace=False)
def test_create_complex():
real = Tensor(np.arange(0, 6).reshape((1, 2, 3)).astype("float32"))
imag = Tensor(np.arange(0, 6).reshape((1, 2, 3)).astype("float32"))
complex = create_complex(real, imag)
np.testing.assert_allclose(complex.numpy(), real.numpy() + imag.numpy() * 1j)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册