diff --git a/tensorflow/python/ops/parallel_for/BUILD b/tensorflow/python/ops/parallel_for/BUILD index 38ac8b1769ba05e8d1912c5d1e3e6ebaeb29d11f..88ddf7a7ec8b3cae640a3aac73c7b02d7a40ceda 100644 --- a/tensorflow/python/ops/parallel_for/BUILD +++ b/tensorflow/python/ops/parallel_for/BUILD @@ -41,6 +41,7 @@ py_library( "//tensorflow/python:tensor_shape", "//tensorflow/python:tensor_util", "//tensorflow/python:util", + "//tensorflow/python/ops/signal", "@absl_py//absl/flags", ], ) diff --git a/tensorflow/python/ops/parallel_for/control_flow_ops_test.py b/tensorflow/python/ops/parallel_for/control_flow_ops_test.py index fd071dd413de431b4c20a13e5cd2ce95b0414868..7d4d77a866e512d72d5bd070a09a2011ae09e2bc 100644 --- a/tensorflow/python/ops/parallel_for/control_flow_ops_test.py +++ b/tensorflow/python/ops/parallel_for/control_flow_ops_test.py @@ -23,6 +23,7 @@ import functools import time from absl import flags +from absl.testing import parameterized import numpy as np from tensorflow.core.example import example_pb2 @@ -58,6 +59,7 @@ from tensorflow.python.ops import tensor_array_ops from tensorflow.python.ops import variables from tensorflow.python.ops.parallel_for import control_flow_ops as pfor_control_flow_ops from tensorflow.python.ops.parallel_for.test_util import PForTestCase +from tensorflow.python.ops.signal import fft_ops from tensorflow.python.platform import test from tensorflow.python.util import nest @@ -1602,6 +1604,65 @@ class PartitionedCallTest(PForTestCase): self._test_loop_fn(loop_fn, 4) +class SpectralTest(PForTestCase, parameterized.TestCase): + + @parameterized.parameters( + (fft_ops.fft,), + (fft_ops.fft2d,), + (fft_ops.fft3d,), + (fft_ops.ifft,), + (fft_ops.ifft2d,), + (fft_ops.ifft3d,), + ) + def test_fft(self, op_func): + shape = [2, 3, 4, 3, 4] + x = np.random.uniform(size=shape) + 1j * np.random.uniform(size=shape) + + def loop_fn(i): + x_i = array_ops.gather(x, i) + return op_func(x_i) + + self._test_loop_fn(loop_fn, 2) + + @parameterized.parameters( + (fft_ops.rfft,), + (fft_ops.rfft2d,), + (fft_ops.rfft3d,), + ) + def test_rfft(self, op_func): + for dtype in (dtypes.float32, dtypes.float64): + x = random_ops.random_uniform([2, 3, 4, 3, 4], dtype=dtype) + + # pylint: disable=cell-var-from-loop + def loop_fn(i): + x_i = array_ops.gather(x, i) + return op_func(x_i) + + # pylint: enable=cell-var-from-loop + + self._test_loop_fn(loop_fn, 2) + + @parameterized.parameters( + (fft_ops.irfft,), + (fft_ops.irfft2d,), + (fft_ops.irfft3d,), + ) + def test_irfft(self, op_func): + for dtype in (dtypes.complex64, dtypes.complex128): + shape = [2, 3, 4, 3, 4] + x = np.random.uniform(size=shape) + 1j * np.random.uniform(size=shape) + x = math_ops.cast(x, dtype=dtype) + + # pylint: disable=cell-var-from-loop + def loop_fn(i): + x_i = array_ops.gather(x, i) + return op_func(x_i) + + # pylint: enable=cell-var-from-loop + + self._test_loop_fn(loop_fn, 2) + + class VariableTest(PForTestCase): def test_create_variable_once(self): diff --git a/tensorflow/python/ops/parallel_for/pfor.py b/tensorflow/python/ops/parallel_for/pfor.py index 6eea73bfb080b13434051c14608b2a92099be052..b01f9a6aba44cf15c223fe99e2fbb2bffe6cd1bd 100644 --- a/tensorflow/python/ops/parallel_for/pfor.py +++ b/tensorflow/python/ops/parallel_for/pfor.py @@ -50,6 +50,7 @@ from tensorflow.python.ops import gen_nn_ops from tensorflow.python.ops import gen_parsing_ops from tensorflow.python.ops import gen_random_ops from tensorflow.python.ops import gen_sparse_ops +from tensorflow.python.ops import gen_spectral_ops from tensorflow.python.ops import linalg_ops from tensorflow.python.ops import map_fn from tensorflow.python.ops import math_ops @@ -3619,3 +3620,29 @@ def _convert_partitioned_call(pfor_input): wrap(call_output, func_output.is_stacked, func_output.is_sparse_stacked)) return outputs + + +# spectral_ops + + +@RegisterPForWithArgs("FFT", gen_spectral_ops.fft) +@RegisterPForWithArgs("FFT2D", gen_spectral_ops.fft2d) +@RegisterPForWithArgs("FFT3D", gen_spectral_ops.fft3d) +@RegisterPForWithArgs("IFFT", gen_spectral_ops.ifft) +@RegisterPForWithArgs("IFFT2D", gen_spectral_ops.ifft2d) +@RegisterPForWithArgs("IFFT3D", gen_spectral_ops.ifft3d) +def _convert_fft(pfor_input, _, op_func): + return wrap(op_func(pfor_input.stacked_input(0)), True) + + +@RegisterPForWithArgs("RFFT", gen_spectral_ops.rfft, "Tcomplex") +@RegisterPForWithArgs("RFFT2D", gen_spectral_ops.rfft2d, "Tcomplex") +@RegisterPForWithArgs("RFFT3D", gen_spectral_ops.rfft3d, "Tcomplex") +@RegisterPForWithArgs("IRFFT", gen_spectral_ops.irfft, "Treal") +@RegisterPForWithArgs("IRFFT2D", gen_spectral_ops.irfft2d, "Treal") +@RegisterPForWithArgs("IRFFT3D", gen_spectral_ops.irfft3d, "Treal") +def _convert_rfft(pfor_input, _, op_func, attr_name): + inp = pfor_input.stacked_input(0) + fft_length = pfor_input.unstacked_input(1) + attr = pfor_input.get_attr(attr_name) + return wrap(op_func(inp, fft_length, attr), True)