提交 104a469d 编写于 作者: T TensorFlower Gardener

Merge pull request #30614 from ROCmSoftwarePlatform:google_upstream_einsum_op

PiperOrigin-RevId: 258349270
...@@ -14,9 +14,9 @@ limitations under the License. ...@@ -14,9 +14,9 @@ limitations under the License.
==============================================================================*/ ==============================================================================*/
#define EIGEN_USE_THREADS #define EIGEN_USE_THREADS
#if GOOGLE_CUDA #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#define EIGEN_USE_GPU #define EIGEN_USE_GPU
#endif // GOOGLE_CUDA #endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#include "tensorflow/core/kernels/einsum_op.h" #include "tensorflow/core/kernels/einsum_op.h"
...@@ -39,9 +39,9 @@ limitations under the License. ...@@ -39,9 +39,9 @@ limitations under the License.
#include "tensorflow/core/platform/types.h" #include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/einsum_op_util.h" #include "tensorflow/core/util/einsum_op_util.h"
#if GOOGLE_CUDA #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#include "tensorflow/core/kernels/reduction_ops_common_gpu.h" #include "tensorflow/core/kernels/reduction_ops_common_gpu.h"
#endif // GOOGLE_CUDA #endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
namespace tensorflow { namespace tensorflow {
...@@ -709,7 +709,7 @@ class EinsumOp : public OpKernel { ...@@ -709,7 +709,7 @@ class EinsumOp : public OpKernel {
bool output_has_ellipsis_ = false; bool output_has_ellipsis_ = false;
}; };
#if GOOGLE_CUDA #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
// Forward declarations of the functor specializations for GPU. // Forward declarations of the functor specializations for GPU.
namespace functor { namespace functor {
#define DECLARE_GPU_SPEC(T, N) \ #define DECLARE_GPU_SPEC(T, N) \
...@@ -736,12 +736,15 @@ namespace functor { ...@@ -736,12 +736,15 @@ namespace functor {
DECLARE_GPU_SPECS(double); DECLARE_GPU_SPECS(double);
DECLARE_GPU_SPECS(float); DECLARE_GPU_SPECS(float);
// TODO(rocm): Enable once complex types are supported.
#if GOOGLE_CUDA
DECLARE_GPU_SPECS(complex64); DECLARE_GPU_SPECS(complex64);
DECLARE_GPU_SPECS(complex128); DECLARE_GPU_SPECS(complex128);
#endif
#undef DECLARE_GPU_SPEC #undef DECLARE_GPU_SPEC
#undef DECLARE_GPU_SPECS #undef DECLARE_GPU_SPECS
} // namespace functor } // namespace functor
#endif // GOOGLE_CUDA #endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#define REGISTER_EINSUM(D, TYPE) \ #define REGISTER_EINSUM(D, TYPE) \
REGISTER_KERNEL_BUILDER( \ REGISTER_KERNEL_BUILDER( \
...@@ -755,14 +758,17 @@ TF_CALL_complex64(REGISTER_CPU); ...@@ -755,14 +758,17 @@ TF_CALL_complex64(REGISTER_CPU);
TF_CALL_complex128(REGISTER_CPU); TF_CALL_complex128(REGISTER_CPU);
#undef REGISTER_CPU #undef REGISTER_CPU
#if GOOGLE_CUDA #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#define REGISTER_GPU(TYPE) REGISTER_EINSUM(GPU, TYPE) #define REGISTER_GPU(TYPE) REGISTER_EINSUM(GPU, TYPE)
TF_CALL_float(REGISTER_GPU); TF_CALL_float(REGISTER_GPU);
TF_CALL_double(REGISTER_GPU); TF_CALL_double(REGISTER_GPU);
// TODO(rocm): Enable once complex types are supported.
#if GOOGLE_CUDA
TF_CALL_complex64(REGISTER_GPU); TF_CALL_complex64(REGISTER_GPU);
TF_CALL_complex128(REGISTER_GPU); TF_CALL_complex128(REGISTER_GPU);
#endif
#undef REGISTER_GPU #undef REGISTER_GPU
#endif // GOOGLE_CUDA #endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#undef REGISTER_EINSUM #undef REGISTER_EINSUM
......
...@@ -18,9 +18,9 @@ limitations under the License. ...@@ -18,9 +18,9 @@ limitations under the License.
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/tensor_types.h" #include "tensorflow/core/framework/tensor_types.h"
#if GOOGLE_CUDA #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#define EIGEN_USE_GPU #define EIGEN_USE_GPU
#endif // GOOGLE_CUDA #endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
namespace tensorflow { namespace tensorflow {
namespace functor { namespace functor {
......
...@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and ...@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. limitations under the License.
==============================================================================*/ ==============================================================================*/
#if GOOGLE_CUDA #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#define EIGEN_USE_GPU #define EIGEN_USE_GPU
#include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/framework/register_types.h"
...@@ -43,4 +43,4 @@ TF_CALL_complex128(DECLARE_GPU_SPECS); ...@@ -43,4 +43,4 @@ TF_CALL_complex128(DECLARE_GPU_SPECS);
} // namespace tensorflow } // namespace tensorflow
#endif // GOOGLE_CUDA #endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册