From a95f95a60b20fb48fbfcae8da4afaa9412582746 Mon Sep 17 00:00:00 2001
From: Gunhan Gulsoy
Date: Tue, 1 Nov 2016 15:46:11 -0800
Subject: [PATCH] Remove references to gcudacc.

Change: 137888607
---
 .../core/common_runtime/gpu/gpu_device.cc    | 76 -------------------
 .../kernels/sparse_tensor_dense_matmul_op.h  | 23 +-----
 tensorflow/core/util/cuda_kernel_helper.h    |  8 --
 tensorflow/stream_executor/device_memory.h   | 17 -----
 4 files changed, 1 insertion(+), 123 deletions(-)

diff --git a/tensorflow/core/common_runtime/gpu/gpu_device.cc b/tensorflow/core/common_runtime/gpu/gpu_device.cc
index e9c48a36e0f..37ab43d90b0 100644
--- a/tensorflow/core/common_runtime/gpu/gpu_device.cc
+++ b/tensorflow/core/common_runtime/gpu/gpu_device.cc
@@ -72,56 +72,6 @@ namespace tensorflow {
 // corresponding stream have completed. The following two classes
 // serve this purpose in two different compilation environments.
 
-#if defined(__GCUDACC__) || defined(__GCUDACC_HOST__)
-class EigenAllocator : public ::Eigen::Allocator {
- public:
-  EigenAllocator() {}
-
-  void Reinitialize(OpKernelContext* context, gpu::Stream* stream,
-                    ::tensorflow::Allocator* alloc, EventMgr* em) {
-    if (LogMemory::IsEnabled()) {
-      operation_ = context->op_kernel().name() + "/EigenAllocator";
-      step_id_ = context->step_id();
-    }
-    stream_ = stream;
-    allocator_ = alloc;
-    em_ = em;
-  }
-
-  void* allocate(size_t num_bytes) const override {
-    void* ret = allocator_->AllocateRaw(32 /* alignment */, num_bytes);
-    // Eigen doesn't typically check the return pointer from allocate,
-    // so we do it here and die with a more helpful error message.
-    if (ret == nullptr) {
-      LOG(FATAL) << "EigenAllocator for GPU ran out of memory when allocating "
-                 << num_bytes << ". See error logs for more detailed info.";
-    }
-    if (LogMemory::IsEnabled()) {
-      LogMemory::RecordRawAllocation(operation_, step_id_, num_bytes, ret,
-                                     allocator_);
-    }
-    return ret;
-  }
-
-  void deallocate(void* buffer) const override {
-    if (LogMemory::IsEnabled()) {
-      LogMemory::RecordRawDeallocation(operation_, step_id_, buffer, allocator_,
-                                       true);
-    }
-    em_->ThenDeleteBuffer(stream_, {allocator_, buffer, operation_, step_id_});
-  }
-
- private:
-  string operation_;
-  int64 step_id_;
-  gpu::Stream* stream_;                 // Not owned.
-  ::tensorflow::Allocator* allocator_;  // Not owned.
-  ::tensorflow::EventMgr* em_;          // Not owned.
-
-  TF_DISALLOW_COPY_AND_ASSIGN(EigenAllocator);
-};
-
-#else
 class EigenCudaStreamDevice : public ::Eigen::StreamInterface {
  public:
   EigenCudaStreamDevice() : scratch_(nullptr), semaphore_(nullptr) {
@@ -216,8 +166,6 @@ class EigenCudaStreamDevice : public ::Eigen::StreamInterface {
   TF_DISALLOW_COPY_AND_ASSIGN(EigenCudaStreamDevice);
 };
 
-#endif
-
 BaseGPUDevice::BaseGPUDevice(const SessionOptions& options, const string& name,
                              Bytes memory_limit, const DeviceLocality& locality,
                              int gpu_id, const string& physical_device_desc,
@@ -515,24 +463,6 @@ Status BaseGPUDevice::MakeTensorFromProto(const TensorProto& tensor_proto,
 }
 
 namespace {
-#if defined(__GCUDACC__) || defined(__GCUDACC_HOST__)
-class ConcretePerOpGpuDevice : public PerOpGpuDevice {
- public:
-  ConcretePerOpGpuDevice() : device_(nullptr) {}
-  void Reinitialize(OpKernelContext* context, gpu::Stream* stream,
-                    Allocator* base_allocator, ::tensorflow::EventMgr* em,
-                    char* scratch) {
-    allocator_.Reinitialize(context, stream, base_allocator, em);
-    device_.Reinitialize(stream, &allocator_, scratch);
-  }
-
-  const Eigen::GpuDevice& device() const override { return device_; }
-
- private:
-  EigenAllocator allocator_;
-  Eigen::GpuDevice device_;
-};
-#else
 class ConcretePerOpGpuDevice : public PerOpGpuDevice {
  public:
   ConcretePerOpGpuDevice() : device_(&stream_device_) {}
@@ -549,7 +479,6 @@ class ConcretePerOpGpuDevice : public PerOpGpuDevice {
   EigenCudaStreamDevice stream_device_;
   Eigen::GpuDevice device_;
 };
-#endif
 }  // namespace
 
 void BaseGPUDevice::ReinitializeDevice(OpKernelContext* context,
@@ -558,15 +487,10 @@ void BaseGPUDevice::ReinitializeDevice(OpKernelContext* context,
   ConcretePerOpGpuDevice* concrete_device =
      static_cast<ConcretePerOpGpuDevice*>(device);
   DCHECK(concrete_device);
-#if defined(__GCUDACC__) || defined(__GCUDACC_HOST__)
-  concrete_device->Reinitialize(context, streams_[stream_id].compute, allocator,
-                                em_.get(), scratch_[stream_id]);
-#else
   const cudaStream_t* cuda_stream = reinterpret_cast<const cudaStream_t*>(
       streams_[stream_id].compute->implementation()->CudaStreamMemberHack());
   concrete_device->Reinitialize(context, cuda_stream, gpu_id_, allocator,
                                 scratch_[stream_id]);
-#endif
 }
 
 PerOpGpuDevice* BaseGPUDevice::MakeGpuDevice() {
diff --git a/tensorflow/core/kernels/sparse_tensor_dense_matmul_op.h b/tensorflow/core/kernels/sparse_tensor_dense_matmul_op.h
index 6106328e7ef..3bec4ce5f2d 100644
--- a/tensorflow/core/kernels/sparse_tensor_dense_matmul_op.h
+++ b/tensorflow/core/kernels/sparse_tensor_dense_matmul_op.h
@@ -55,34 +55,13 @@ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T MaybeConj(T v) {
   return v;
 }
 
-#ifdef __GCUDACC__
-// TODO(ebrevdo): remove this once a bugfix is in.
-#define MAYBE_CONJ(T)                                       \
-  template <>                                               \
-  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T MaybeConj(T v) {  \
-    assert(false && "Conjugation not supported");           \
-  }
-#else
-#define MAYBE_CONJ(T)                                       \
-  template <>                                               \
-  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T MaybeConj(T v) {  \
-    return Eigen::numext::conj(v);                          \
-  }
-#endif
-
-MAYBE_CONJ(std::complex);
-MAYBE_CONJ(std::complex);
-MAYBE_CONJ(std::complex);
-
-#undef MAYBE_CONJ
-
 template <typename MATRIX>
 class MaybeAdjoint<MATRIX, true> {
  public:
   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE MaybeAdjoint(MATRIX m) : m_(m) {}
   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE typename MATRIX::Scalar operator()(
       const typename MATRIX::Index i, const typename MATRIX::Index j) const {
-    return MaybeConj(m_(j, i));
+    return Eigen::numext::conj(m_(j, i));
   }
 
  private:
diff --git a/tensorflow/core/util/cuda_kernel_helper.h b/tensorflow/core/util/cuda_kernel_helper.h
index 488c28e5305..a0b7341c798 100644
--- a/tensorflow/core/util/cuda_kernel_helper.h
+++ b/tensorflow/core/util/cuda_kernel_helper.h
@@ -77,16 +77,8 @@ __device__ __host__ inline T ldg(const T* address) {
 #define CUDA_ATOMIC_WRAPPER(op, T) \
   __device__ __forceinline__ T CudaAtomic##op(T* address, T val)
 
-// Reason of guarding: NVCC cannot compile the "::" in "cuda_builtin::atomicOp".
-#ifdef __GCUDACC__
-using cuda_builtin::__float_as_int;
-using cuda_builtin::__int_as_float;
-#define USE_CUDA_ATOMIC(op, T) \
-  CUDA_ATOMIC_WRAPPER(op, T) { return cuda_builtin::atomic##op(address, val); }
-#else
 #define USE_CUDA_ATOMIC(op, T) \
   CUDA_ATOMIC_WRAPPER(op, T) { return atomic##op(address, val); }
-#endif
 
 // For atomicAdd.
 USE_CUDA_ATOMIC(Add, int32);
diff --git a/tensorflow/stream_executor/device_memory.h b/tensorflow/stream_executor/device_memory.h
index eb73133d313..bcb0664b043 100644
--- a/tensorflow/stream_executor/device_memory.h
+++ b/tensorflow/stream_executor/device_memory.h
@@ -145,23 +145,6 @@ class DeviceMemory final : public DeviceMemoryBase {
   }
 
   // ------------------------------------------------------------
-  // DO NOT USE - FASTR TEAM-INTERNAL FUNCTIONS
-  // Used internally by gcudacc.
-#ifdef __GCUDACC__
-  // Implicit conversion operators needed to support mixed mode. Since buffer
-  // sizes aren't used in the CUDA launching process, and since the constructed
-  // objects are all temporary, this is safe.
-  // Linter warning disabled as we require an implicit conversion.
-  DeviceMemory(const ElemT *opaque) :  // NOLINT
-      DeviceMemoryBase(reinterpret_cast<void *>(const_cast<ElemT *>(opaque)),
-                       0) {}
-
-  operator ElemT *() { return reinterpret_cast<ElemT *>(opaque()); }
-  operator const ElemT *() {
-    return const_cast<const ElemT *>(reinterpret_cast<ElemT *>(opaque()));
-  }
-#endif
-
   // ------------------------------------------------------------
  protected:
   // This constructor is solely used from derived classes; it is made protected
--
GitLab