/**
 * Copyright (c) Facebook, Inc. and its affiliates.
 *
 * This source code is licensed under the MIT license found in the
 * LICENSE file in the root directory of this source tree.
 */


#pragma once

#include <faiss/MetricType.h>
#include <faiss/gpu/utils/DeviceTensor.cuh>
#include <faiss/gpu/utils/Float16.cuh>

#include <cuda.h>
#include <thrust/execution_policy.h>
#include <thrust/transform.h>

namespace faiss { namespace gpu {

//
// Conversion utilities
//
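// Convert<From, To> is an element-wise conversion functor: the generic form
// is a plain cast, and the half<->float specializations below (enabled with
// FAISS_USE_FLOAT16) route through the CUDA half-precision intrinsics.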

template <typename From, typename To>
struct Convert {
  inline __device__ To operator()(From v) const {
    return (To) v;
  }
};

#ifdef FAISS_USE_FLOAT16
template <>
struct Convert<float, half> {
  inline __device__ half operator()(float v) const {
    return __float2half(v);
  }
};

template <>
struct Convert<half, float> {
  inline __device__ float operator()(half v) const {
    return __half2float(v);
  }
};
#endif

// ConvertTo<T> converts the supported storage types (float and, when
// FAISS_USE_FLOAT16 is defined, the corresponding half types) to T.
template <typename T>
struct ConvertTo {
};

template <>
struct ConvertTo<float> {
  static inline __device__ float to(float v) { return v; }
#ifdef FAISS_USE_FLOAT16
  static inline __device__ float to(half v) { return __half2float(v); }
#endif
};

template <>
struct ConvertTo<float2> {
  static inline __device__ float2 to(float2 v) { return v; }
#ifdef FAISS_USE_FLOAT16
  static inline __device__ float2 to(half2 v) { return __half22float2(v); }
#endif
};

template <>
struct ConvertTo<float4> {
  static inline __device__ float4 to(float4 v) { return v; }
#ifdef FAISS_USE_FLOAT16
  static inline __device__ float4 to(Half4 v) { return half4ToFloat4(v); }
#endif
};

#ifdef FAISS_USE_FLOAT16
template <>
struct ConvertTo<half> {
  static inline __device__ half to(float v) { return __float2half(v); }
  static inline __device__ half to(half v) { return v; }
};
#endif

#ifdef FAISS_USE_FLOAT16
template <>
struct ConvertTo<half2> {
  static inline __device__ half2 to(float2 v) { return __float22half2_rn(v); }
  static inline __device__ half2 to(half2 v) { return v; }
};
#endif

#ifdef FAISS_USE_FLOAT16
template <>
struct ConvertTo<Half4> {
  static inline __device__ Half4 to(float4 v) { return float4ToHalf4(v); }
  static inline __device__ Half4 to(Half4 v) { return v; }
};
#endif

// Tensor conversion
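// runConvert applies Convert<From, To> element-wise to `num` contiguous
// values on the given stream via thrust::transform.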
template <typename From, typename To>
void runConvert(const From* in,
                To* out,
                size_t num,
                cudaStream_t stream) {
  thrust::transform(thrust::cuda::par.on(stream),
                    in, in + num, out, Convert<From, To>());
}

// Converts `in` into the preallocated tensor `out`; both tensors must hold
// the same number of elements.
template <typename From, typename To, int Dim>
void convertTensor(cudaStream_t stream,
                   Tensor<From, Dim, true>& in,
                   Tensor<To, Dim, true>& out) {
  FAISS_ASSERT(in.numElements() == out.numElements());

  runConvert<From, To>(in.data(), out.data(), in.numElements(), stream);
}

// Converts `in` into a newly allocated tensor of the same shape, using the
// memory manager from `res` when one is provided.
template <typename From, typename To, int Dim>
DeviceTensor<To, Dim, true> convertTensor(GpuResources* res,
                                          cudaStream_t stream,
                                          Tensor<From, Dim, true>& in) {
  DeviceTensor<To, Dim, true> out;

  if (res) {
    out = DeviceTensor<To, Dim, true>(res->getMemoryManagerCurrentDevice(),
                                      in.sizes(),
                                      stream);
  } else {
    out = DeviceTensor<To, Dim, true>(in.sizes());
  }

  convertTensor(stream, in, out);
  return out;
}

} } // namespace
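
// Usage sketch (illustrative only; `res`, `stream`, and the 2-D float device
// tensor `src` are assumed to already exist):
//
//   #ifdef FAISS_USE_FLOAT16
//   faiss::gpu::DeviceTensor<half, 2, true> dst =
//       faiss::gpu::convertTensor<float, half, 2>(res, stream, src);
//   #endif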