提交 3a552b92 编写于 作者: A A. Unique TensorFlower 提交者: TensorFlower Gardener

Add pfor converters for

 FFT
 FFT2D
 FFT3D
 IFFT
 IFFT2D
 IFFT3D
 RFFT
 RFFT2D
 RFFT3D
 IRFFT
 IRFFT2D
 IRFFT3D

PiperOrigin-RevId: 295026807
Change-Id: I572e2f8fbe94eb9f30a2a17fc929190fbb5450df
上级 e46798a8
......@@ -41,6 +41,7 @@ py_library(
"//tensorflow/python:tensor_shape",
"//tensorflow/python:tensor_util",
"//tensorflow/python:util",
"//tensorflow/python/ops/signal",
"@absl_py//absl/flags",
],
)
......
......@@ -23,6 +23,7 @@ import functools
import time
from absl import flags
from absl.testing import parameterized
import numpy as np
from tensorflow.core.example import example_pb2
......@@ -58,6 +59,7 @@ from tensorflow.python.ops import tensor_array_ops
from tensorflow.python.ops import variables
from tensorflow.python.ops.parallel_for import control_flow_ops as pfor_control_flow_ops
from tensorflow.python.ops.parallel_for.test_util import PForTestCase
from tensorflow.python.ops.signal import fft_ops
from tensorflow.python.platform import test
from tensorflow.python.util import nest
......@@ -1602,6 +1604,65 @@ class PartitionedCallTest(PForTestCase):
self._test_loop_fn(loop_fn, 4)
class SpectralTest(PForTestCase, parameterized.TestCase):
@parameterized.parameters(
(fft_ops.fft,),
(fft_ops.fft2d,),
(fft_ops.fft3d,),
(fft_ops.ifft,),
(fft_ops.ifft2d,),
(fft_ops.ifft3d,),
)
def test_fft(self, op_func):
shape = [2, 3, 4, 3, 4]
x = np.random.uniform(size=shape) + 1j * np.random.uniform(size=shape)
def loop_fn(i):
x_i = array_ops.gather(x, i)
return op_func(x_i)
self._test_loop_fn(loop_fn, 2)
@parameterized.parameters(
(fft_ops.rfft,),
(fft_ops.rfft2d,),
(fft_ops.rfft3d,),
)
def test_rfft(self, op_func):
for dtype in (dtypes.float32, dtypes.float64):
x = random_ops.random_uniform([2, 3, 4, 3, 4], dtype=dtype)
# pylint: disable=cell-var-from-loop
def loop_fn(i):
x_i = array_ops.gather(x, i)
return op_func(x_i)
# pylint: enable=cell-var-from-loop
self._test_loop_fn(loop_fn, 2)
@parameterized.parameters(
(fft_ops.irfft,),
(fft_ops.irfft2d,),
(fft_ops.irfft3d,),
)
def test_irfft(self, op_func):
for dtype in (dtypes.complex64, dtypes.complex128):
shape = [2, 3, 4, 3, 4]
x = np.random.uniform(size=shape) + 1j * np.random.uniform(size=shape)
x = math_ops.cast(x, dtype=dtype)
# pylint: disable=cell-var-from-loop
def loop_fn(i):
x_i = array_ops.gather(x, i)
return op_func(x_i)
# pylint: enable=cell-var-from-loop
self._test_loop_fn(loop_fn, 2)
class VariableTest(PForTestCase):
def test_create_variable_once(self):
......
......@@ -50,6 +50,7 @@ from tensorflow.python.ops import gen_nn_ops
from tensorflow.python.ops import gen_parsing_ops
from tensorflow.python.ops import gen_random_ops
from tensorflow.python.ops import gen_sparse_ops
from tensorflow.python.ops import gen_spectral_ops
from tensorflow.python.ops import linalg_ops
from tensorflow.python.ops import map_fn
from tensorflow.python.ops import math_ops
......@@ -3619,3 +3620,29 @@ def _convert_partitioned_call(pfor_input):
wrap(call_output, func_output.is_stacked,
func_output.is_sparse_stacked))
return outputs
# spectral_ops
@RegisterPForWithArgs("FFT", gen_spectral_ops.fft)
@RegisterPForWithArgs("FFT2D", gen_spectral_ops.fft2d)
@RegisterPForWithArgs("FFT3D", gen_spectral_ops.fft3d)
@RegisterPForWithArgs("IFFT", gen_spectral_ops.ifft)
@RegisterPForWithArgs("IFFT2D", gen_spectral_ops.ifft2d)
@RegisterPForWithArgs("IFFT3D", gen_spectral_ops.ifft3d)
def _convert_fft(pfor_input, _, op_func):
return wrap(op_func(pfor_input.stacked_input(0)), True)
@RegisterPForWithArgs("RFFT", gen_spectral_ops.rfft, "Tcomplex")
@RegisterPForWithArgs("RFFT2D", gen_spectral_ops.rfft2d, "Tcomplex")
@RegisterPForWithArgs("RFFT3D", gen_spectral_ops.rfft3d, "Tcomplex")
@RegisterPForWithArgs("IRFFT", gen_spectral_ops.irfft, "Treal")
@RegisterPForWithArgs("IRFFT2D", gen_spectral_ops.irfft2d, "Treal")
@RegisterPForWithArgs("IRFFT3D", gen_spectral_ops.irfft3d, "Treal")
def _convert_rfft(pfor_input, _, op_func, attr_name):
inp = pfor_input.stacked_input(0)
fft_length = pfor_input.unstacked_input(1)
attr = pfor_input.get_attr(attr_name)
return wrap(op_func(inp, fft_length, attr), True)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册