/**
 * Copyright (c) Facebook, Inc. and its affiliates.
 *
 * This source code is licensed under the MIT license found in the
 * LICENSE file in the root directory of this source tree.
 */

#pragma once

// NOTE(review): in the reviewed text every `<...>` span was missing (include
// targets, template parameter lists, template argument lists). The include
// targets below are reconstructed from the names this header uses
// (cudaStream_t, GpuResources, Tensor/DeviceTensor, half/Half4 helpers,
// thrust::transform, thrust::cuda::par) -- confirm against the build.
#include <cuda.h>
#include <faiss/gpu/GpuResources.h>
#include <faiss/gpu/utils/DeviceTensor.cuh>
#include <faiss/gpu/utils/Float16.cuh>
#include <thrust/execution_policy.h>
#include <thrust/transform.h>

namespace faiss {
namespace gpu {

//
// Conversion utilities
//

/// Unary device functor that converts a single value of type `From` to type
/// `To`. The generic version performs a plain C-style cast.
template <typename From, typename To>
struct Convert {
    inline __device__ To operator()(From v) const {
        return (To)v;
    }
};

#ifdef FAISS_USE_FLOAT16
/// float -> half specialization; uses the `__float2half` intrinsic instead of
/// a cast.
template <>
struct Convert<float, half> {
    inline __device__ half operator()(float v) const {
        return __float2half(v);
    }
};

/// half -> float specialization; uses the `__half2float` intrinsic instead of
/// a cast.
template <>
struct Convert<half, float> {
    inline __device__ float operator()(half v) const {
        return __half2float(v);
    }
};
#endif

/// Destination-typed conversion: `ConvertTo<T>::to(v)` returns `v` as a `T`.
/// The primary template is intentionally empty; only the specializations
/// below provide `to()` overloads, so an unsupported destination type fails
/// at compile time.
template <typename T>
struct ConvertTo {};

/// Conversions producing `float` (identity, and from `half` when float16
/// support is compiled in).
template <>
struct ConvertTo<float> {
    static inline __device__ float to(float v) {
        return v;
    }
#ifdef FAISS_USE_FLOAT16
    static inline __device__ float to(half v) {
        return __half2float(v);
    }
#endif
};

/// Conversions producing `float2` (identity, and from `half2` when float16
/// support is compiled in).
template <>
struct ConvertTo<float2> {
    static inline __device__ float2 to(float2 v) {
        return v;
    }
#ifdef FAISS_USE_FLOAT16
    static inline __device__ float2 to(half2 v) {
        return __half22float2(v);
    }
#endif
};

/// Conversions producing `float4` (identity, and from `Half4` via
/// `half4ToFloat4` when float16 support is compiled in).
template <>
struct ConvertTo<float4> {
    static inline __device__ float4 to(float4 v) {
        return v;
    }
#ifdef FAISS_USE_FLOAT16
    static inline __device__ float4 to(Half4 v) {
        return half4ToFloat4(v);
    }
#endif
};

#ifdef FAISS_USE_FLOAT16
/// Conversions producing `half` (from `float`, and identity).
template <>
struct ConvertTo<half> {
    static inline __device__ half to(float v) {
        return __float2half(v);
    }
    static inline __device__ half to(half v) {
        return v;
    }
};
#endif

#ifdef FAISS_USE_FLOAT16
/// Conversions producing `half2` (from `float2` via `__float22half2_rn`, and
/// identity).
template <>
struct ConvertTo<half2> {
    static inline __device__ half2 to(float2 v) {
        return __float22half2_rn(v);
    }
    static inline __device__ half2 to(half2 v) {
        return v;
    }
};
#endif

#ifdef FAISS_USE_FLOAT16
/// Conversions producing `Half4` (from `float4` via `float4ToHalf4`, and
/// identity).
template <>
struct ConvertTo<Half4> {
    static inline __device__ Half4 to(float4 v) {
        return float4ToHalf4(v);
    }
    static inline __device__ Half4 to(Half4 v) {
        return v;
    }
};
#endif

// Tensor conversion

/// Converts `num` elements starting at `in` into `out`, element by element,
/// using `Convert<From, To>`.
///
/// @param in     first source element
/// @param out    first destination element; must have room for `num` values
/// @param num    number of elements to convert
/// @param stream CUDA stream handed to Thrust's `thrust::cuda::par.on()`
///               execution policy
template <typename From, typename To>
void runConvert(const From* in, To* out, size_t num, cudaStream_t stream) {
    thrust::transform(
            thrust::cuda::par.on(stream),
            in,
            in + num,
            out,
            Convert<From, To>());
}

/// Converts every element of `in` into the already-allocated tensor `out`.
/// Asserts (FAISS_ASSERT) that both tensors hold the same number of elements;
/// only the element counts are compared.
///
/// NOTE(review): the `Tensor<..., Dim, true>` argument lists are reconstructed
/// (they were missing from the reviewed text) -- confirm.
template <typename From, typename To, int Dim>
void convertTensor(
        cudaStream_t stream,
        Tensor<From, Dim, true>& in,
        Tensor<To, Dim, true>& out) {
    FAISS_ASSERT(in.numElements() == out.numElements());

    runConvert<From, To>(in.data(), out.data(), in.numElements(), stream);
}

/// Allocates a new `DeviceTensor` with the same sizes as `in`, fills it with
/// the converted contents of `in`, and returns it.
///
/// @param res    if non-null, the output is constructed from
///               `res->getMemoryManagerCurrentDevice()`, `in.sizes()` and
///               `stream`; if null, it is constructed from `in.sizes()` alone
/// @param stream CUDA stream used for the conversion
/// @param in     source tensor
/// @return the newly created tensor holding the converted values
template <typename From, typename To, int Dim>
DeviceTensor<To, Dim, true> convertTensor(
        GpuResources* res,
        cudaStream_t stream,
        Tensor<From, Dim, true>& in) {
    DeviceTensor<To, Dim, true> out;

    // The right-hand sides are temporaries, so they already bind to the move
    // assignment operator; wrapping them in std::move() was redundant.
    if (res) {
        out = DeviceTensor<To, Dim, true>(
                res->getMemoryManagerCurrentDevice(), in.sizes(), stream);
    } else {
        out = DeviceTensor<To, Dim, true>(in.sizes());
    }

    convertTensor(stream, in, out);
    return out;
}

} // namespace gpu
} // namespace faiss