diff --git a/tensorflow/core/kernels/einsum_op.cc b/tensorflow/core/kernels/einsum_op.cc index bca7fca7f3d6ee2a7bbbf3be61ce7a9a490ae592..75136cd9f017adb868f5f62cb4d060448f98581a 100644 --- a/tensorflow/core/kernels/einsum_op.cc +++ b/tensorflow/core/kernels/einsum_op.cc @@ -14,9 +14,9 @@ limitations under the License. ==============================================================================*/ #define EIGEN_USE_THREADS -#if GOOGLE_CUDA +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM #define EIGEN_USE_GPU -#endif // GOOGLE_CUDA +#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM #include "tensorflow/core/kernels/einsum_op.h" @@ -39,9 +39,9 @@ limitations under the License. #include "tensorflow/core/platform/types.h" #include "tensorflow/core/util/einsum_op_util.h" -#if GOOGLE_CUDA +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM #include "tensorflow/core/kernels/reduction_ops_common_gpu.h" -#endif // GOOGLE_CUDA +#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM namespace tensorflow { @@ -709,7 +709,7 @@ class EinsumOp : public OpKernel { bool output_has_ellipsis_ = false; }; -#if GOOGLE_CUDA +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM // Forward declarations of the functor specializations for GPU. namespace functor { #define DECLARE_GPU_SPEC(T, N) \ @@ -736,12 +736,15 @@ namespace functor { DECLARE_GPU_SPECS(double); DECLARE_GPU_SPECS(float); +// TODO(rocm): Enable once complex types are supported. +#if GOOGLE_CUDA DECLARE_GPU_SPECS(complex64); DECLARE_GPU_SPECS(complex128); +#endif #undef DECLARE_GPU_SPEC #undef DECLARE_GPU_SPECS } // namespace functor -#endif // GOOGLE_CUDA +#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM #define REGISTER_EINSUM(D, TYPE) \ REGISTER_KERNEL_BUILDER( \ @@ -755,14 +758,17 @@ TF_CALL_complex64(REGISTER_CPU); TF_CALL_complex128(REGISTER_CPU); #undef REGISTER_CPU -#if GOOGLE_CUDA +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM #define REGISTER_GPU(TYPE) REGISTER_EINSUM(GPU, TYPE) TF_CALL_float(REGISTER_GPU); TF_CALL_double(REGISTER_GPU); +// TODO(rocm): Enable once complex types are supported. +#if GOOGLE_CUDA TF_CALL_complex64(REGISTER_GPU); TF_CALL_complex128(REGISTER_GPU); +#endif #undef REGISTER_GPU -#endif // GOOGLE_CUDA +#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM #undef REGISTER_EINSUM diff --git a/tensorflow/core/kernels/einsum_op.h b/tensorflow/core/kernels/einsum_op.h index 8ac1bbc5fe5785f7e5754316247192203686c02a..31d1109004cc9ad20d153b05e9602df4d6349097 100644 --- a/tensorflow/core/kernels/einsum_op.h +++ b/tensorflow/core/kernels/einsum_op.h @@ -18,9 +18,9 @@ limitations under the License. #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" #include "tensorflow/core/framework/tensor_types.h" -#if GOOGLE_CUDA +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM #define EIGEN_USE_GPU -#endif // GOOGLE_CUDA +#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM namespace tensorflow { namespace functor { diff --git a/tensorflow/core/kernels/einsum_op_gpu.cu.cc b/tensorflow/core/kernels/einsum_op_gpu.cu.cc index e7adbe571e7a51193aa825d9f86a8ac0227a7b1f..fa1c8cbb4a5fe83df27daff4f444e623d75da3e4 100644 --- a/tensorflow/core/kernels/einsum_op_gpu.cu.cc +++ b/tensorflow/core/kernels/einsum_op_gpu.cu.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#if GOOGLE_CUDA +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM #define EIGEN_USE_GPU #include "tensorflow/core/framework/register_types.h" @@ -43,4 +43,4 @@ TF_CALL_complex128(DECLARE_GPU_SPECS); } // namespace tensorflow -#endif // GOOGLE_CUDA +#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM