未验证 提交 6cebd714 编写于 作者: C chentianyu03 提交者: GitHub

add + - * / @ [] operator to ComplexVariable (#28217)

* add + - * / @ [] operator to ComplexVariable, also add unittest

* fix circular reference bug

* fit for py2.7

* remove reverse oprators which not supported now
上级 a98c69b6
...@@ -15,7 +15,7 @@ ...@@ -15,7 +15,7 @@
from __future__ import print_function from __future__ import print_function
from .. import core from .. import core
from ..framework import Variable, convert_np_dtype_to_dtype_, _varbase_creator from ..framework import Variable, convert_np_dtype_to_dtype_, _varbase_creator, ComplexVariable
from ..layers.layer_function_generator import OpProtoHolder from ..layers.layer_function_generator import OpProtoHolder
from . import no_grad from . import no_grad
...@@ -149,6 +149,13 @@ def monkey_patch_math_varbase(): ...@@ -149,6 +149,13 @@ def monkey_patch_math_varbase():
reverse=False, reverse=False,
scalar_method=None): scalar_method=None):
def __impl__(self, other_var): def __impl__(self, other_var):
# tensor and ComplexVariable opetator
if isinstance(other_var, ComplexVariable):
# need import paddle in closure
import paddle
math_op = getattr(paddle.incubate.complex.tensor, op_type)
return math_op(self, other_var)
# FIXME(zjl): elementwise_div between integers cannot be converted to scale, # FIXME(zjl): elementwise_div between integers cannot be converted to scale,
# which may lose accuracy. This is a hot fix for release 1.6. # which may lose accuracy. This is a hot fix for release 1.6.
if scalar_method is not None and not ( if scalar_method is not None and not (
......
...@@ -1826,6 +1826,9 @@ class ComplexVariable(object): ...@@ -1826,6 +1826,9 @@ class ComplexVariable(object):
self._dtype = "complex128" self._dtype = "complex128"
self._shape = self.real.shape self._shape = self.real.shape
def __getitem__(self, idx):
return ComplexVariable(self.real[idx], self.imag[idx])
@property @property
def dtype(self): def dtype(self):
return self._dtype return self._dtype
......
...@@ -47,23 +47,36 @@ class TestComplexElementwiseLayers(unittest.TestCase): ...@@ -47,23 +47,36 @@ class TestComplexElementwiseLayers(unittest.TestCase):
self.assertTrue(np.allclose(self.calc(x, y, "mul", place), x * y)) self.assertTrue(np.allclose(self.calc(x, y, "mul", place), x * y))
self.assertTrue(np.allclose(self.calc(x, y, "div", place), x / y)) self.assertTrue(np.allclose(self.calc(x, y, "div", place), x / y))
def compare_op(self, x, y):
for place in self._places:
with dg.guard(place):
var_x = dg.to_variable(x)
var_y = dg.to_variable(y)
self.assertTrue(var_x + var_y, x + y)
self.assertTrue(var_x - var_y, x - y)
self.assertTrue(var_x * var_y, x * y)
self.assertTrue(var_x / var_y, x / y)
def test_complex_xy(self): def test_complex_xy(self):
x = rand([2, 3, 4, 5]).astype(self._dtype) + 1j * rand( x = rand([2, 3, 4, 5]).astype(self._dtype) + 1j * rand(
[2, 3, 4, 5]).astype(self._dtype) [2, 3, 4, 5]).astype(self._dtype)
y = rand([2, 3, 4, 5]).astype(self._dtype) + 1j * rand( y = rand([2, 3, 4, 5]).astype(self._dtype) + 1j * rand(
[2, 3, 4, 5]).astype(self._dtype) [2, 3, 4, 5]).astype(self._dtype)
self.compare(x, y) self.compare(x, y)
self.compare_op(x, y)
def test_complex_x_real_y(self): def test_complex_x_real_y(self):
x = rand([2, 3, 4, 5]).astype(self._dtype) + 1j * rand( x = rand([2, 3, 4, 5]).astype(self._dtype) + 1j * rand(
[2, 3, 4, 5]).astype(self._dtype) [2, 3, 4, 5]).astype(self._dtype)
y = rand([4, 5]).astype(self._dtype) y = rand([4, 5]).astype(self._dtype)
self.compare(x, y) self.compare(x, y)
self.compare_op(x, y)
def test_real_x_complex_y(self): def test_real_x_complex_y(self):
x = rand([2, 3, 4, 5]).astype(self._dtype) x = rand([2, 3, 4, 5]).astype(self._dtype)
y = rand([5]).astype(self._dtype) + 1j * rand([5]).astype(self._dtype) y = rand([5]).astype(self._dtype) + 1j * rand([5]).astype(self._dtype)
self.compare(x, y) self.compare(x, y)
self.compare_op(x, y)
if __name__ == '__main__': if __name__ == '__main__':
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import paddle
import numpy as np
import paddle.fluid as fluid
import paddle.fluid.dygraph as dg
class TestComplexGetitemLayer(unittest.TestCase):
def setUp(self):
self._places = [fluid.CPUPlace()]
if fluid.core.is_compiled_with_cuda():
self._places.append(fluid.CUDAPlace(0))
def test_case1(self):
x_np = np.random.randn(2, 3, 4) + 1j * np.random.randn(2, 3, 4)
x_np_slice = x_np[0]
for place in self._places:
with dg.guard(place):
x_var = dg.to_variable(x_np)
x_var_slice = x_var[0]
np.testing.assert_allclose(x_var_slice.numpy(), x_np_slice)
def test_case2(self):
x_np = np.random.randn(2, 3, 4) + 1j * np.random.randn(2, 3, 4)
x_np_slice = x_np[0][1]
for place in self._places:
with dg.guard(place):
x_var = dg.to_variable(x_np)
x_var_slice = x_var[0][1]
np.testing.assert_allclose(x_var_slice.numpy(), x_np_slice)
def test_case3(self):
x_np = np.random.randn(2, 3, 4) + 1j * np.random.randn(2, 3, 4)
x_np_slice = x_np[0][1][2]
for place in self._places:
with dg.guard(place):
x_var = dg.to_variable(x_np)
x_var_slice = x_var[0][1][2]
np.testing.assert_allclose(x_var_slice.numpy(), x_np_slice)
def test_case4(self):
x_np = np.random.randn(2, 3, 4) + 1j * np.random.randn(2, 3, 4)
x_np_slice = x_np[0][1][0:3]
for place in self._places:
with dg.guard(place):
x_var = dg.to_variable(x_np)
x_var_slice = x_var[0][1][0:3]
np.testing.assert_allclose(x_var_slice.numpy(), x_np_slice)
def test_case5(self):
x_np = np.random.randn(2, 3, 4) + 1j * np.random.randn(2, 3, 4)
x_np_slice = x_np[0][1][0:4:2]
for place in self._places:
with dg.guard(place):
x_var = dg.to_variable(x_np)
x_var_slice = x_var[0][1][0:4:2]
np.testing.assert_allclose(x_var_slice.numpy(), x_np_slice)
def test_case6(self):
x_np = np.random.randn(2, 3, 4) + 1j * np.random.randn(2, 3, 4)
x_np_slice = x_np[0][1:3][0:4:2]
for place in self._places:
with dg.guard(place):
x_var = dg.to_variable(x_np)
x_var_slice = x_var[0][1:3][0:4:2]
np.testing.assert_allclose(x_var_slice.numpy(), x_np_slice)
if __name__ == '__main__':
unittest.main()
...@@ -34,6 +34,15 @@ class TestComplexMatMulLayer(unittest.TestCase): ...@@ -34,6 +34,15 @@ class TestComplexMatMulLayer(unittest.TestCase):
np_result = np.matmul(x, y) np_result = np.matmul(x, y)
self.assertTrue(np.allclose(result.numpy(), np_result)) self.assertTrue(np.allclose(result.numpy(), np_result))
def compare_op(self, x, y):
for place in self._places:
with dg.guard(place):
x_var = dg.to_variable(x)
y_var = dg.to_variable(y)
result = x_var.matmul(y_var)
np_result = np.matmul(x, y)
self.assertTrue(np.allclose(result.numpy(), np_result))
def test_complex_xy(self): def test_complex_xy(self):
x = np.random.random( x = np.random.random(
(2, 3, 4, 5)).astype("float32") + 1J * np.random.random( (2, 3, 4, 5)).astype("float32") + 1J * np.random.random(
...@@ -42,6 +51,7 @@ class TestComplexMatMulLayer(unittest.TestCase): ...@@ -42,6 +51,7 @@ class TestComplexMatMulLayer(unittest.TestCase):
(2, 3, 5, 4)).astype("float32") + 1J * np.random.random( (2, 3, 5, 4)).astype("float32") + 1J * np.random.random(
(2, 3, 5, 4)).astype("float32") (2, 3, 5, 4)).astype("float32")
self.compare(x, y) self.compare(x, y)
self.compare_op(x, y)
def test_complex_x(self): def test_complex_x(self):
x = np.random.random( x = np.random.random(
...@@ -49,6 +59,7 @@ class TestComplexMatMulLayer(unittest.TestCase): ...@@ -49,6 +59,7 @@ class TestComplexMatMulLayer(unittest.TestCase):
(2, 3, 4, 5)).astype("float32") (2, 3, 4, 5)).astype("float32")
y = np.random.random((2, 3, 5, 4)).astype("float32") y = np.random.random((2, 3, 5, 4)).astype("float32")
self.compare(x, y) self.compare(x, y)
self.compare_op(x, y)
def test_complex_y(self): def test_complex_y(self):
x = np.random.random((2, 3, 4, 5)).astype("float32") x = np.random.random((2, 3, 4, 5)).astype("float32")
......
...@@ -13,6 +13,9 @@ ...@@ -13,6 +13,9 @@
# limitations under the License. # limitations under the License.
from . import tensor from . import tensor
from .tensor_op_patch import monkey_patch_math_complex
from .tensor import * from .tensor import *
__all__ = tensor.__all__ + [] __all__ = tensor.__all__ + []
monkey_patch_math_complex()
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import division
from ...fluid import framework
from . import tensor
def monkey_patch_math_complex():
# complexVariable do not support scaler type now, so here not contains
# reverse methods, such as "__radd__", "__rsub__", "__rmul__", "__rdiv__",
# "__rtruediv__", "__rmatmul__".
complex_methods = [
('__add__', _binary_creator_('__add__', "elementwise_add", False)),
('__sub__', _binary_creator_('__sub__', "elementwise_sub", False)),
('__mul__', _binary_creator_('__mul__', "elementwise_mul", False)),
('__div__', _binary_creator_('__div__', "elementwise_div", False)),
('__truediv__', _binary_creator_('__truediv__', "elementwise_div",
False)),
('__matmul__', _binary_creator_('__matmul__', "matmul", False)),
]
for method in complex_methods:
method_name = method[0]
method_impl = method[1]
if method_impl:
setattr(framework.ComplexVariable, method_name, method_impl)
for method in tensor.__all__:
method_impl = getattr(tensor, method)
if method_impl:
setattr(framework.ComplexVariable, method, method_impl)
# for binary operator such as elementwise
def _binary_creator_(method_name, op_type, reverse=False):
def __impl__(self, other_var):
math_op = getattr(tensor, op_type)
return math_op(self, other_var)
__impl__.__name__ = method_name
return __impl__
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册