未验证 提交 720d1899 编写于 作者: Y Yibing Liu 提交者: GitHub

Init complex number neural network (#24018)

* Init complex number neural network, test=develop

* Improve doc writing, test=develop

* Fix elementwise add & sub, test=develop

* Fix elementwise mul act, test=develop

* a) add ut for complex variable; b) remove arg act in elementwise_ops. test=develop
上级 34d7d6ae
......@@ -38,6 +38,7 @@ import paddle.tensor
import paddle.nn
import paddle.framework
import paddle.imperative
import paddle.complex
# TODO: define alias in tensor and framework directory
# from .tensor.creation import create_.tensor #DEFINE_ALIAS
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from . import tensor
from .tensor import *
__all__ = tensor.__all__ + []
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ..fluid import framework
def is_complex(x):
"""
Return true if the input(x) is a ComplexVariable.
"""
return isinstance(x, framework.ComplexVariable)
def is_real(x):
"""
Return true if the input(x) is a real number Variable.
"""
return isinstance(x, framework.Variable)
def complex_variable_exists(inputs, layer_name):
for inp in inputs:
if is_complex(inp):
return
err_msg = "At least one inputs of layer complex." if len(inputs) > 1 \
else "The input of layer complex."
raise ValueError(err_msg + layer_name +
"() must be ComplexVariable, please "
"use the layer for real numher instead.")
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from . import math
from .math import *
__all__ = math.__all__ + []
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ..helper import is_complex, is_real, complex_variable_exists
from ...fluid.framework import ComplexVariable
from ...fluid import layers
__all__ = [
'elementwise_add', 'elementwise_sub', 'elementwise_mul', 'elementwise_div'
]
def elementwise_add(x, y, axis=-1, name=None):
"""
The element-wise addition layer for complex number inputs. At least one of
inputs :attr:`x` and :attr:`y` must be a ComplexVariable. See the detailed
description for the function and other arguments
in :ref:`api_fluid_layers_elementwise_add` .
Args:
x (Variable|ComplexVariable): The first input Variable or ComplexVariable
with any number of dimensions. The supported data types include float32
and float64 when it is a Variable. Otherwise the supported data types
are complex64 or complex128.
y (Variable|ComplexVariable): The second input Variable or ComplexVariable
with any number of dimensions. The supported data types include float32
and float64 when it is a Variable. Otherwise the supported data types
are complex64 or complex128.
Examples:
.. code-block:: python
import numpy as np
import paddle
import paddle.fluid.dygraph as dg
a = np.array([[1.0+1.0j, 2.0+1.0j], [3.0+1.0j, 4.0+1.0j]])
b = np.array([[5.0+2.0j, 6.0+2.0j], [7.0+2.0j, 8.0+2.0j]])
with dg.guard():
x = dg.to_variable(a)
y = dg.to_variable(b)
out = paddle.complex.elementwise_add(x, y)
print(out.numpy())
# [[ 6.+3.j 8.+3.j]
# [10.+3.j 12.+3.j]]
"""
complex_variable_exists([x, y], "elementwise_add")
(x_real, x_imag) = (x.real, x.imag) if is_complex(x) else (x, None)
(y_real, y_imag) = (y.real, y.imag) if is_complex(y) else (y, None)
real = layers.elementwise_add(x_real, y_real, axis=axis, name=name)
if is_real(x_imag) and is_real(y_imag):
imag = layers.elementwise_add(x_imag, y_imag, axis=axis, name=name)
elif is_real(x_imag):
imag = layers.assign(x_imag)
else:
imag = layers.elementwise_add(
layers.zeros_like(x_real), y_imag, axis=axis, name=name)
return ComplexVariable(real, imag)
def elementwise_sub(x, y, axis=-1, name=None):
"""
The element-wise subtraction layer for complex number inputs. At least one of
inputs :attr:`x` and :attr:`y` must be a ComplexVariable. See the detailed
description for the function and other arguments
in :ref:`api_fluid_layers_elementwise_sub` .
Args:
x (Variable|ComplexVariable): The first input Variable or ComplexVariable
with any number of dimensions. The supported data types include float32
and float64 when it is a Variable. Otherwise the supported data types
are complex64 or complex128.
y (Variable|ComplexVariable): The second input Variable or ComplexVariable
with any number of dimensions. The supported data types include float32
and float64 when it is a Variable. Otherwise the supported data types
are complex64 or complex128.
Examples:
.. code-block:: python
import numpy as np
import paddle
import paddle.fluid.dygraph as dg
a = np.array([[1.0+1.0j, 2.0+1.0j], [3.0+1.0j, 4.0+1.0j]])
b = np.array([[5.0+2.0j, 6.0+2.0j], [7.0+2.0j, 8.0+2.0j]])
with dg.guard():
x = dg.to_variable(a)
y = dg.to_variable(b)
out = paddle.complex.elementwise_sub(x, y)
print(out.numpy())
# [[-4.-1.j -4.-1.j]
# [-4.-1.j -4.-1.j]]
"""
complex_variable_exists([x, y], "elementwise_sub")
(x_real, x_imag) = (x.real, x.imag) if is_complex(x) else (x, None)
(y_real, y_imag) = (y.real, y.imag) if is_complex(y) else (y, None)
real = layers.elementwise_sub(x_real, y_real, axis=axis, name=name)
if is_real(x_imag) and is_real(y_imag):
imag = layers.elementwise_sub(x_imag, y_imag, axis=axis, name=name)
elif is_real(x_imag):
imag = layers.assign(x_imag)
else:
imag = layers.elementwise_sub(
layers.zeros_like(x_real), y_imag, axis=axis, name=name)
return ComplexVariable(real, imag)
def elementwise_mul(x, y, axis=-1, name=None):
"""
The element-wise multiplication layer for complex number inputs. At least
one of inputs :attr:`x` and :attr:`y` must be a ComplexVariable. See the
detailed description for the function and other arguments
in :ref:`api_fluid_layers_elementwise_mul` .
Args:
x (Variable|ComplexVariable): The first input Variable or ComplexVariable
with any number of dimensions. The supported data types include float32
and float64 when it is a Variable. Otherwise the supported data types
are complex64 or complex128.
y (Variable|ComplexVariable): The second input Variable or ComplexVariable
with any number of dimensions. The supported data types include float32
and float64 when it is a Variable. Otherwise the supported data types
are complex64 or complex128.
Examples:
.. code-block:: python
import numpy as np
import paddle
import paddle.fluid.dygraph as dg
a = np.array([[1.0+1.0j, 2.0+1.0j], [3.0+1.0j, 4.0+1.0j]])
b = np.array([[5.0+2.0j, 6.0+2.0j], [7.0+2.0j, 8.0+2.0j]])
with dg.guard():
x = dg.to_variable(a)
y = dg.to_variable(b)
out = paddle.complex.elementwise_mul(x, y)
print(out.numpy())
# [[ 3. +7.j 10.+10.j]
# [19.+13.j 30.+16.j]]
"""
complex_variable_exists([x, y], "elementwise_mul")
# (a + bi)(c + di) = (ac - bd) + (bc + ad)i
(a, b) = (x.real, x.imag) if is_complex(x) else (x, None)
(c, d) = (y.real, y.imag) if is_complex(y) else (y, None)
ac = layers.elementwise_mul(a, c, axis=axis, name=name)
bd = layers.elementwise_mul(
b, d, axis=axis, name=name) if is_real(b) and is_real(d) else None
bc = layers.elementwise_mul(
b, c, axis=axis, name=name) if is_real(b) else None
ad = layers.elementwise_mul(
a, d, axis=axis, name=name) if is_real(d) else None
real = ac - bd if is_real(bd) else ac
imag = bc + ad if is_real(bc) and is_real(ad) else bc if is_real(bc) else ad
return ComplexVariable(real, imag)
def elementwise_div(x, y, axis=-1, name=None):
"""
The element-wise division layer for complex number inputs. At least one of
inputs :attr:`x` and :attr:`y` must be a ComplexVariable. See the detailed
description for the function and other arguments
in :ref:`api_fluid_layers_elementwise_div` .
Args:
x (Variable|ComplexVariable): The first input Variable or ComplexVariable
with any number of dimensions. The supported data types include float32
and float64 when it is a Variable. Otherwise the supported data types
are complex64 or complex128.
y (Variable|ComplexVariable): The second input Variable or ComplexVariable
with any number of dimensions. The supported data types include float32
and float64 when it is a Variable. Otherwise the supported data types
are complex64 or complex128.
Examples:
.. code-block:: python
import numpy as np
import paddle
import paddle.fluid.dygraph as dg
a = np.array([[1.0+1.0j, 2.0+1.0j], [3.0+1.0j, 4.0+1.0j]])
b = np.array([[5.0+2.0j, 6.0+2.0j], [7.0+2.0j, 8.0+2.0j]])
with dg.guard():
x = dg.to_variable(a)
y = dg.to_variable(b)
out = paddle.complex.elementwise_div(x, y)
print(out.numpy())
# [[0.24137931+0.10344828j 0.35 +0.05j ]
# [0.43396226+0.01886792j 0.5 +0.j ]]
"""
complex_variable_exists([x, y], "elementwise_div")
# (a + bi)/(c + di) = (a + bi)(c - di)/(c^2 + d^2)
(c, d) = (y.real, y.imag) if is_complex(y) else (y, None)
y_conj = ComplexVariable(c, -d) if is_real(d) else c
e = 1 / (layers.pow(c, 2.0) + layers.pow(d, 2.0)
) if is_real(d) else 1 / layers.pow(c, 2.0)
return elementwise_mul(
elementwise_mul(
x, y_conj, axis=axis, name=name),
e,
axis=axis,
name=name)
......@@ -492,15 +492,24 @@ def grad(outputs,
@framework.dygraph_only
def to_variable(value, name=None, zero_copy=None):
"""
The API will create a ``Variable`` object from numpy\.ndarray or Variable object.
The API will create a ``Variable`` or ``ComplexVariable`` object from
numpy\.ndarray, Variable or ComplexVariable object.
Parameters:
value(ndarray|Variable): The numpy\.ndarray or Variable object that needs to be converted, it can be multi-dimension, and the data type is one of numpy\.{float16, float32, float64, int16, int32, int64, uint8, uint16}.
name(str, optional): The default value is None. Normally there is no need for user to set this property. For more information, please refer to :ref:`api_guide_Name`
zero_copy(bool, optional): Whether to share memory with the input numpy array. This parameter only works with CPUPlace and will be set to True when it is None. Default: None.
value(ndarray|Variable|ComplexVariable): The numpy\.ndarray, Variable
or ComplexVariable object that needs to be converted, it can be
multi-dimension, and the data type is one of numpy\.{float16,
float32, float64, int16, int32, int64, uint8, uint16, complex64,
complex128}.
name(str, optional): The default value is None. Normally there is no
need for user to set this property. For more information, please
refer to :ref:`api_guide_Name` .
zero_copy(bool, optional): Whether to share memory with the input numpy
array. This parameter only works with CPUPlace and will be set to
True when it is None. Default: None.
Returns:
Variable: If ``value`` is a numpy\.ndarray object, return ``Tensor`` created from the specified numpy\.ndarray object, which has same data type and shape with ``value``. If ``value`` is a Variable object, just return ``value``.
Variable or ComplexVariable: If ``value`` is a numpy\.ndarray object, return ``Tensor`` created from the specified numpy\.ndarray object, which has same data type and shape with ``value``. If ``value`` is a Variable or ComplexVariable object, just return ``value``.
Examples:
......@@ -518,7 +527,10 @@ def to_variable(value, name=None, zero_copy=None):
y = fluid.dygraph.to_variable(x)
x[0][0] = 0
y[0][0].numpy() # array([0.], dtype=float32)
c = np.array([2+1j, 2])
z = fluid.dygraph.to_variable(c)
z.numpy() # array([2.+1.j, 2.+0.j])
z.dtype # 'complex128'
"""
if isinstance(value, np.ndarray):
assert framework.in_dygraph_mode(
......@@ -530,16 +542,34 @@ def to_variable(value, name=None, zero_copy=None):
else:
assert not zero_copy, "zero_copy mode can only be used with CPUPlace"
zero_copy = False
py_var = core.VarBase(
value=value,
place=framework._current_expected_place(),
persistable=False,
zero_copy=zero_copy,
name=name if name else '')
return py_var
elif isinstance(value, (core.VarBase, framework.Variable)):
if np.iscomplexobj(value):
if not name:
name = framework.unique_name.generate('_generated_var')
real_var = core.VarBase(
value=value.real,
place=framework._current_expected_place(),
persistable=False,
zero_copy=zero_copy,
name=name + ".real")
imag_var = core.VarBase(
value=value.imag,
place=framework._current_expected_place(),
persistable=False,
zero_copy=zero_copy,
name=name + ".imag")
return framework.ComplexVariable(real_var, imag_var)
else:
py_var = core.VarBase(
value=value,
place=framework._current_expected_place(),
persistable=False,
zero_copy=zero_copy,
name=name if name else '')
return py_var
elif isinstance(value, (core.VarBase, framework.Variable,
framework.ComplexVariable)):
return value
else:
raise TypeError(
"The type of input value is invalid, expected type is 'ndarray' or 'Variable', but received %s"
% type(value))
"The type of input value is invalid, expected type is 'ndarray', "
"'Variable' or 'ComplexVariable', but received %s." % type(value))
......@@ -49,6 +49,7 @@ __all__ = [
'in_dygraph_mode',
'is_compiled_with_cuda',
'Variable',
'ComplexVariable',
'load_op_library',
'require_version',
'device_guard',
......@@ -1657,6 +1658,88 @@ def get_all_op_protos():
return ret_values
class ComplexVariable(object):
"""
The Variable defined on the complex number domain. It contains two common
real number Variables as its members, :attr:`real` and :attr:`imag`
holding the real part and imaginary part of complex numbers respectively.
**Notes**:
**The constructor of Variable should not be invoked directly.**
**Only support dygraph mode at present. Please use** :ref:`api_fluid_dygraph_to_variable` **
to create a dygraph ComplexVariable with complex number data.**
Args:
real (Variable): The Variable holding real-part data.
imag (Variable): The Variable holding imaginery-part data.
Examples:
.. code-block:: python
import paddle.fluid as fluid
import numpy as np
a = np.array([1.0+2.0j, 0.2])
with fluid.dygraph.guard():
var = fluid.dygraph.to_variable(a, name="new_var")
print(var.name, var.dtype, var.shape)
# ({'real': u'new_var.real', 'imag': u'new_var.imag'}, 'complex128', [2L])
print(var.numpy())
# [1. +2.j 0.2+0.j]
"""
def __init__(self, real, imag):
assert real.shape == imag.shape, "The real part and imaginary part " \
"of a ComplexVariable should have the same shape!"
assert real.dtype == imag.dtype, "The real part and imaginary part " \
"of a ComplexVariable should have the same data type!"
self.real = real
self.imag = imag
if self.real.dtype in [
core.VarDesc.VarType.FP16, core.VarDesc.VarType.FP32
]:
self._dtype = "complex64"
else:
self._dtype = "complex128"
self._shape = self.real.shape
@property
def dtype(self):
return self._dtype
@property
def shape(self):
return self._shape
@property
def name(self):
return {"real": self.real.name, "imag": self.imag.name}
@name.setter
def name(self, name):
# rename
if isinstance(name, str):
self.real.name = name + ".real"
self.imag.name = name + ".imag"
elif (isinstance(name, tuple) or isinstance(name,
list)) and len(name) == 2:
self.real.name, self.imag.name = name[0], name[1]
else:
raise ValueError(
"An invalid name assigned to the ComplexVariable, "
"which must be a string, or a tuple or a list with length 2!")
def numpy(self):
return self.real.numpy() + 1j * self.imag.numpy()
def __str__(self):
return "REAL: " + self.real.__str__() + "IMAG: " + self.imag.__str__()
__repr__ = __str__
class OpProtoHolder(object):
"""
A global variable to hold all OpProtos from C++ as a map
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
from numpy.random import random as rand
import paddle.complex as cpx
import paddle.fluid as fluid
import paddle.fluid.dygraph as dg
layers = {
"add": cpx.elementwise_add,
"sub": cpx.elementwise_sub,
"mul": cpx.elementwise_mul,
"div": cpx.elementwise_div,
}
class TestComplexElementwiseLayers(unittest.TestCase):
def setUp(self):
self._dtype = "float64"
self._places = [fluid.CPUPlace()]
if fluid.core.is_compiled_with_cuda():
self._places.append(fluid.CUDAPlace(0))
def calc(self, x, y, layer_type, place):
with dg.guard(place):
var_x = dg.to_variable(x)
var_y = dg.to_variable(y)
return layers[layer_type](var_x, var_y).numpy()
def compare(self, x, y):
for place in self._places:
self.assertTrue(np.allclose(self.calc(x, y, "add", place), x + y))
self.assertTrue(np.allclose(self.calc(x, y, "sub", place), x - y))
self.assertTrue(np.allclose(self.calc(x, y, "mul", place), x * y))
self.assertTrue(np.allclose(self.calc(x, y, "div", place), x / y))
def test_complex_xy(self):
x = rand([2, 3, 4, 5]).astype(self._dtype) + 1j * rand(
[2, 3, 4, 5]).astype(self._dtype)
y = rand([2, 3, 4, 5]).astype(self._dtype) + 1j * rand(
[2, 3, 4, 5]).astype(self._dtype)
self.compare(x, y)
def test_complex_x_real_y(self):
x = rand([2, 3, 4, 5]).astype(self._dtype) + 1j * rand(
[2, 3, 4, 5]).astype(self._dtype)
y = rand([4, 5]).astype(self._dtype)
self.compare(x, y)
def test_real_x_complex_y(self):
x = rand([2, 3, 4, 5]).astype(self._dtype)
y = rand([5]).astype(self._dtype) + 1j * rand([5]).astype(self._dtype)
self.compare(x, y)
if __name__ == '__main__':
unittest.main()
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
import paddle
import paddle.fluid.dygraph as dg
class TestComplexVariable(unittest.TestCase):
def compare(self):
a = np.array([[1.0 + 1.0j, 2.0 + 1.0j],
[3.0 + 1.0j, 4.0 + 1.0j]]).astype(self._dtype)
b = np.array([[1.0 + 1.0j, 1.0 + 1.0j]]).astype(self._dtype)
with dg.guard():
x = dg.to_variable(a, "x")
y = dg.to_variable(b)
out = paddle.complex.elementwise_add(x, y)
self.assertIsNotNone("{}".format(out))
self.assertTrue(np.allclose(out.numpy(), a + b))
self.assertEqual(x.name, {'real': 'x.real', 'imag': 'x.imag'})
x.name = "new_x"
self.assertEqual(x.name, {'real': 'new_x.real', 'imag': 'new_x.imag'})
self.assertEqual(out.dtype, self._dtype)
self.assertEqual(out.shape, x.shape)
def test_attrs(self):
self._dtype = "complex64"
self.compare()
self._dtype = "complex128"
self.compare()
if __name__ == '__main__':
unittest.main()
......@@ -139,11 +139,11 @@ packages=['paddle',
'paddle.dataset',
'paddle.reader',
'paddle.distributed',
'paddle.fluid',
'paddle.tensor',
'paddle.complex',
'paddle.complex.tensor',
'paddle.framework',
'paddle.fluid',
'paddle.fluid.dygraph',
'paddle.tensor',
'paddle.fluid.dygraph.dygraph_to_static',
'paddle.fluid.proto',
'paddle.fluid.proto.profiler',
......@@ -181,7 +181,9 @@ packages=['paddle',
'paddle.nn',
'paddle.nn.functional',
'paddle.nn.layer',
'paddle.imperative']
'paddle.imperative',
'paddle.tensor',
]
with open('@PADDLE_SOURCE_DIR@/python/requirements.txt') as f:
setup_requires = f.read().splitlines()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册