/**
 * Copyright (c) Facebook, Inc. and its affiliates.
 *
 * This source code is licensed under the MIT license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <faiss/gpu/impl/PQScanMultiPassPrecomputed.cuh>
#include <faiss/gpu/GpuResources.h>
#include <faiss/gpu/impl/PQCodeLoad.cuh>
#include <faiss/gpu/impl/IVFUtils.cuh>
#include <faiss/gpu/utils/ConversionOperators.cuh>
#include <faiss/gpu/utils/DeviceTensor.cuh>
#include <faiss/gpu/utils/DeviceUtils.h>
#include <faiss/gpu/utils/Float16.cuh>
#include <faiss/gpu/utils/LoadStoreOperators.cuh>
#include <faiss/gpu/utils/MathOperators.cuh>
#include <faiss/gpu/utils/StaticUtils.h>
#include <thrust/device_vector.h>
#include <algorithm>
#include <limits>

namespace faiss { namespace gpu {

// For precomputed codes, this calculates and loads code distances
// into smem
template <typename LookupT, typename LookupVecT>
inline __device__ void
loadPrecomputedTerm(LookupT* smem,
                    LookupT* term2Start,
                    LookupT* term3Start,
                    int numCodes) {
  constexpr int kWordSize = sizeof(LookupVecT) / sizeof(LookupT);

  // We can only use vector loads if the data is guaranteed to be
  // aligned. The codes are innermost, so if it is evenly divisible,
  // then any slice will be aligned.
  if (numCodes % kWordSize == 0) {
    constexpr int kUnroll = 2;

    // Load the data by float4 for efficiency, and then handle any remainder
    // limitVec is the number of whole vec words we can load, in terms
    // of whole blocks performing the load
    int limitVec = numCodes / (kUnroll * kWordSize * blockDim.x);
    limitVec *= kUnroll * blockDim.x;

    LookupVecT* smemV = (LookupVecT*) smem;
    LookupVecT* term2StartV = (LookupVecT*) term2Start;
    LookupVecT* term3StartV = (LookupVecT*) term3Start;

    for (int i = threadIdx.x; i < limitVec; i += kUnroll * blockDim.x) {
      LookupVecT vals[kUnroll];

#pragma unroll
      for (int j = 0; j < kUnroll; ++j) {
        vals[j] =
          LoadStore<LookupVecT>::load(&term2StartV[i + j * blockDim.x]);
      }

#pragma unroll
      for (int j = 0; j < kUnroll; ++j) {
        LookupVecT q =
          LoadStore<LookupVecT>::load(&term3StartV[i + j * blockDim.x]);

        vals[j] = Math<LookupVecT>::add(vals[j], q);
      }

#pragma unroll
      for (int j = 0; j < kUnroll; ++j) {
        LoadStore<LookupVecT>::store(&smemV[i + j * blockDim.x], vals[j]);
      }
    }

    // This is where we start loading the remainder that does not evenly
    // fit into kUnroll x blockDim.x
    int remainder = limitVec * kWordSize;

    for (int i = remainder + threadIdx.x; i < numCodes; i += blockDim.x) {
      smem[i] = Math<LookupT>::add(term2Start[i], term3Start[i]);
    }
  } else {
    // Potential unaligned load
    constexpr int kUnroll = 4;

    int limit = utils::roundDown(numCodes, kUnroll * blockDim.x);

    int i = threadIdx.x;
    for (; i < limit; i += kUnroll * blockDim.x) {
      LookupT vals[kUnroll];

#pragma unroll
      for (int j = 0; j < kUnroll; ++j) {
        vals[j] = term2Start[i + j * blockDim.x];
      }

#pragma unroll
      for (int j = 0; j < kUnroll; ++j) {
        vals[j] = Math<LookupT>::add(vals[j], term3Start[i + j * blockDim.x]);
      }

#pragma unroll
      for (int j = 0; j < kUnroll; ++j) {
        smem[i + j * blockDim.x] = vals[j];
      }
    }

    for (; i < numCodes; i += blockDim.x) {
      smem[i] = Math<LookupT>::add(term2Start[i], term3Start[i]);
    }
  }
}

template <int NumSubQuantizers, typename LookupT, typename LookupVecT>
__global__ void
pqScanPrecomputedMultiPass(Tensor<float, 2, true> queries,
                           Tensor<float, 2, true> precompTerm1,
                           Tensor<LookupT, 3, true> precompTerm2,
                           Tensor<LookupT, 3, true> precompTerm3,
                           Tensor<int, 2, true> topQueryToCentroid,
                           void** listCodes,
                           int* listLengths,
                           Tensor<int, 2, true> prefixSumOffsets,
                           Tensor<float, 1, true> distance) {
  // precomputed term 2 + 3 storage
  // (sub q)(code id)
  extern __shared__ char smemTerm23[];
  LookupT* term23 = (LookupT*) smemTerm23;

  // Each block handles a single (query, probe) pair
  auto queryId = blockIdx.y;
  auto probeId = blockIdx.x;
  auto codesPerSubQuantizer = precompTerm2.getSize(2);
  auto precompTermSize = precompTerm2.getSize(1) * codesPerSubQuantizer;

  // This is where we start writing out data
  // We ensure that before the array (at offset -1), there is a 0 value
  int outBase = *(prefixSumOffsets[queryId][probeId].data() - 1);
  float* distanceOut = distance[outBase].data();

  auto listId = topQueryToCentroid[queryId][probeId];
  // Safety guard in case NaNs in input cause no list ID to be generated
  if (listId == -1) {
    return;
  }
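  // Precomputed-table decomposition: for a database vector encoded as
  // residual PQ codes (c_1, ..., c_M) in this list, the L2 distance to the
  // query decomposes as
  //   d = term1 + sum_m (term2[listId][m][c_m] + term3[queryId][m][c_m])
  // term1 = ||query - coarse centroid||^2 depends only on (query, probe),
  // term2 depends only on (list, sub-quantizer, code), and term3 only on
  // (query, sub-quantizer, code). term2 + term3 is summed into shared memory
  // below, so the scan loop does a single table lookup per code byte.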
  unsigned char* codeList = (unsigned char*) listCodes[listId];
  int limit = listLengths[listId];

  constexpr int kNumCode32 = NumSubQuantizers <= 4 ? 1 :
    (NumSubQuantizers / 4);
  unsigned int code32[kNumCode32];
  unsigned int nextCode32[kNumCode32];

  // We double-buffer the code loading, which improves memory utilization
  if (threadIdx.x < limit) {
    LoadCode32<NumSubQuantizers>::load(code32, codeList, threadIdx.x);
  }

  // Load precomputed terms 1, 2, 3
  float term1 = precompTerm1[queryId][probeId];
  loadPrecomputedTerm<LookupT, LookupVecT>(term23,
                                           precompTerm2[listId].data(),
                                           precompTerm3[queryId].data(),
                                           precompTermSize);

  // Prevent WAR dependencies
  __syncthreads();

  // Each thread handles one code element in the list, with a
  // block-wide stride
  for (int codeIndex = threadIdx.x;
       codeIndex < limit;
       codeIndex += blockDim.x) {
    // Prefetch next codes
    if (codeIndex + blockDim.x < limit) {
      LoadCode32<NumSubQuantizers>::load(
        nextCode32, codeList, codeIndex + blockDim.x);
    }

    float dist = term1;

#pragma unroll
    for (int word = 0; word < kNumCode32; ++word) {
      constexpr int kBytesPerCode32 =
        NumSubQuantizers < 4 ? NumSubQuantizers : 4;

      if (kBytesPerCode32 == 1) {
        auto code = code32[0];
        dist = ConvertTo<float>::to(term23[code]);

      } else {
#pragma unroll
        for (int byte = 0; byte < kBytesPerCode32; ++byte) {
          auto code = getByte(code32[word], byte * 8, 8);

          auto offset =
            codesPerSubQuantizer * (word * kBytesPerCode32 + byte);

          dist += ConvertTo<float>::to(term23[offset + code]);
        }
      }
    }

    // Write out intermediate distance result
    // We do not maintain indices here, in order to reduce global
    // memory traffic. Those are recovered in the final selection step.
    distanceOut[codeIndex] = dist;

    // Rotate buffers
#pragma unroll
    for (int word = 0; word < kNumCode32; ++word) {
      code32[word] = nextCode32[word];
    }
  }
}

void
runMultiPassTile(Tensor<float, 2, true>& queries,
                 Tensor<float, 2, true>& precompTerm1,
                 NoTypeTensor<3, true>& precompTerm2,
                 NoTypeTensor<3, true>& precompTerm3,
                 Tensor<int, 2, true>& topQueryToCentroid,
                 Tensor<uint8_t, 1, true>& bitset,
                 bool useFloat16Lookup,
                 int bytesPerCode,
                 int numSubQuantizers,
                 int numSubQuantizerCodes,
                 thrust::device_vector<void*>& listCodes,
                 thrust::device_vector<void*>& listIndices,
                 IndicesOptions indicesOptions,
                 thrust::device_vector<int>& listLengths,
                 Tensor<char, 1, true>& thrustMem,
                 Tensor<int, 2, true>& prefixSumOffsets,
                 Tensor<float, 1, true>& allDistances,
                 Tensor<float, 3, true>& heapDistances,
                 Tensor<int, 3, true>& heapIndices,
                 int k,
                 Tensor<float, 2, true>& outDistances,
                 Tensor<long, 2, true>& outIndices,
                 cudaStream_t stream) {
  // Calculate offset lengths, so we know where to write out
  // intermediate results
  runCalcListOffsets(topQueryToCentroid, listLengths, prefixSumOffsets,
                     thrustMem, stream);

  // Convert all codes to a distance, and write out (distance,
  // index) values for all intermediate results
  {
    auto kThreadsPerBlock = 256;

    auto grid = dim3(topQueryToCentroid.getSize(1),
                     topQueryToCentroid.getSize(0));
    auto block = dim3(kThreadsPerBlock);

    // pq precomputed terms (2 + 3)
    auto smem = sizeof(float);
#ifdef FAISS_USE_FLOAT16
    if (useFloat16Lookup) {
      smem = sizeof(half);
    }
#endif
    smem *= numSubQuantizers * numSubQuantizerCodes;
    FAISS_ASSERT(smem <= getMaxSharedMemPerBlockCurrentDevice());

#define RUN_PQ_OPT(NUM_SUB_Q, LOOKUP_T, LOOKUP_VEC_T)                   \
    do {                                                                \
      auto precompTerm2T = precompTerm2.toTensor<LOOKUP_T>();           \
      auto precompTerm3T = precompTerm3.toTensor<LOOKUP_T>();           \
                                                                        \
      pqScanPrecomputedMultiPass<NUM_SUB_Q, LOOKUP_T, LOOKUP_VEC_T>     \
        <<<grid, block, smem, stream>>>(                                \
          queries,                                                      \
          precompTerm1,                                                 \
          precompTerm2T,                                                \
          precompTerm3T,                                                \
          topQueryToCentroid,                                           \
          listCodes.data().get(),                                       \
          listLengths.data().get(),                                     \
          prefixSumOffsets,                                             \
          allDistances);                                                \
    } while (0)

#ifdef FAISS_USE_FLOAT16
#define RUN_PQ(NUM_SUB_Q)                       \
    do {                                        \
      if (useFloat16Lookup) {                   \
        RUN_PQ_OPT(NUM_SUB_Q, half, Half8);     \
      } else {                                  \
        RUN_PQ_OPT(NUM_SUB_Q, float, float4);   \
      }                                         \
    } while (0)
#else
#define RUN_PQ(NUM_SUB_Q)                       \
    do {                                        \
      RUN_PQ_OPT(NUM_SUB_Q, float, float4);     \
    } while (0)
#endif // FAISS_USE_FLOAT16
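    // Dispatch on the number of code bytes per encoded vector so that
    // NumSubQuantizers is a compile-time constant in the kernel and the
    // per-code loops fully unroll. Only the byte counts listed below are
    // supported; any other value is a logic error and asserts.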
    switch (bytesPerCode) {
      case 1:
        RUN_PQ(1);
        break;
      case 2:
        RUN_PQ(2);
        break;
      case 3:
        RUN_PQ(3);
        break;
      case 4:
        RUN_PQ(4);
        break;
      case 8:
        RUN_PQ(8);
        break;
      case 12:
        RUN_PQ(12);
        break;
      case 16:
        RUN_PQ(16);
        break;
      case 20:
        RUN_PQ(20);
        break;
      case 24:
        RUN_PQ(24);
        break;
      case 28:
        RUN_PQ(28);
        break;
      case 32:
        RUN_PQ(32);
        break;
      case 40:
        RUN_PQ(40);
        break;
      case 48:
        RUN_PQ(48);
        break;
      case 56:
        RUN_PQ(56);
        break;
      case 64:
        RUN_PQ(64);
        break;
      case 96:
        RUN_PQ(96);
        break;
      default:
        FAISS_ASSERT(false);
        break;
    }

    CUDA_TEST_ERROR();

#undef RUN_PQ
#undef RUN_PQ_OPT
  }

  // k-select the output in chunks, to increase parallelism
  runPass1SelectLists(listIndices,
                      indicesOptions,
                      prefixSumOffsets,
                      topQueryToCentroid,
                      bitset,
                      allDistances,
                      topQueryToCentroid.getSize(1),
                      k,
                      false, // L2 distance chooses smallest
                      heapDistances,
                      heapIndices,
                      stream);

  // k-select final output
  auto flatHeapDistances = heapDistances.downcastInner<2>();
  auto flatHeapIndices = heapIndices.downcastInner<2>();

  runPass2SelectLists(flatHeapDistances,
                      flatHeapIndices,
                      listIndices,
                      indicesOptions,
                      prefixSumOffsets,
                      topQueryToCentroid,
                      k,
                      false, // L2 distance chooses smallest
                      outDistances,
                      outIndices,
                      stream);

  CUDA_TEST_ERROR();
}

void
runPQScanMultiPassPrecomputed(Tensor<float, 2, true>& queries,
                              Tensor<float, 2, true>& precompTerm1,
                              NoTypeTensor<3, true>& precompTerm2,
                              NoTypeTensor<3, true>& precompTerm3,
                              Tensor<int, 2, true>& topQueryToCentroid,
                              Tensor<uint8_t, 1, true>& bitset,
                              bool useFloat16Lookup,
                              int bytesPerCode,
                              int numSubQuantizers,
                              int numSubQuantizerCodes,
                              thrust::device_vector<void*>& listCodes,
                              thrust::device_vector<void*>& listIndices,
                              IndicesOptions indicesOptions,
                              thrust::device_vector<int>& listLengths,
                              int maxListLength,
                              int k,
                              // output
                              Tensor<float, 2, true>& outDistances,
                              // output
                              Tensor<long, 2, true>& outIndices,
                              GpuResources* res) {
  constexpr int kMinQueryTileSize = 8;
  constexpr int kMaxQueryTileSize = 128;
  constexpr int kThrustMemSize = 16384;

  int nprobe = topQueryToCentroid.getSize(1);

  auto& mem = res->getMemoryManagerCurrentDevice();
  auto stream = res->getDefaultStreamCurrentDevice();

  // Make a reservation for Thrust to do its dirty work (global memory
  // cross-block reduction space); hopefully this is large enough.
  DeviceTensor<char, 1, true> thrustMem1(
    mem, {kThrustMemSize}, stream);
  DeviceTensor<char, 1, true> thrustMem2(
    mem, {kThrustMemSize}, stream);
  DeviceTensor<char, 1, true>* thrustMem[2] =
    {&thrustMem1, &thrustMem2};
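  // All of the per-tile temporaries in this function (thrustMem above, and
  // the prefixSumOffsets / allDistances / heap buffers below) are allocated
  // in pairs so that consecutive query tiles can be processed on the two
  // alternating streams; curStream selects which member of each pair the
  // current tile uses.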
  // How much temporary storage is available?
  // If possible, we'd like to fit within the space available.
  size_t sizeAvailable = mem.getSizeAvailable();

  // We run two passes of heap selection
  // This is the size of the first-level heap passes
  constexpr int kNProbeSplit = 8;
  int pass2Chunks = std::min(nprobe, kNProbeSplit);

  size_t sizeForFirstSelectPass =
    pass2Chunks * k * (sizeof(float) + sizeof(int));

  // How much temporary storage we need per each query
  size_t sizePerQuery =
    2 * // # streams
    ((nprobe * sizeof(int) + sizeof(int)) + // prefixSumOffsets
     nprobe * maxListLength * sizeof(float) + // allDistances
     sizeForFirstSelectPass);

  int queryTileSize = (int) (sizeAvailable / sizePerQuery);

  if (queryTileSize < kMinQueryTileSize) {
    queryTileSize = kMinQueryTileSize;
  } else if (queryTileSize > kMaxQueryTileSize) {
    queryTileSize = kMaxQueryTileSize;
  }

  // FIXME: we should adjust queryTileSize to deal with this, since
  // indexing is in int32
  FAISS_ASSERT(queryTileSize * nprobe * maxListLength <=
               std::numeric_limits<int>::max());

  // Temporary memory buffers
  // Make sure there is space prior to the start which will be 0, and
  // will handle the boundary condition without branches
  DeviceTensor<int, 1, true> prefixSumOffsetSpace1(
    mem, {queryTileSize * nprobe + 1}, stream);
  DeviceTensor<int, 1, true> prefixSumOffsetSpace2(
    mem, {queryTileSize * nprobe + 1}, stream);

  DeviceTensor<int, 2, true> prefixSumOffsets1(
    prefixSumOffsetSpace1[1].data(),
    {queryTileSize, nprobe});
  DeviceTensor<int, 2, true> prefixSumOffsets2(
    prefixSumOffsetSpace2[1].data(),
    {queryTileSize, nprobe});
  DeviceTensor<int, 2, true>* prefixSumOffsets[2] =
    {&prefixSumOffsets1, &prefixSumOffsets2};

  // Make sure the element before prefixSumOffsets is 0, since we
  // depend upon simple, boundary-less indexing to get proper results
  CUDA_VERIFY(cudaMemsetAsync(prefixSumOffsetSpace1.data(),
                              0,
                              sizeof(int),
                              stream));
  CUDA_VERIFY(cudaMemsetAsync(prefixSumOffsetSpace2.data(),
                              0,
                              sizeof(int),
                              stream));

  DeviceTensor<float, 1, true> allDistances1(
    mem, {queryTileSize * nprobe * maxListLength}, stream);
  DeviceTensor<float, 1, true> allDistances2(
    mem, {queryTileSize * nprobe * maxListLength}, stream);
  DeviceTensor<float, 1, true>* allDistances[2] =
    {&allDistances1, &allDistances2};

  DeviceTensor<float, 3, true> heapDistances1(
    mem, {queryTileSize, pass2Chunks, k}, stream);
  DeviceTensor<float, 3, true> heapDistances2(
    mem, {queryTileSize, pass2Chunks, k}, stream);
  DeviceTensor<float, 3, true>* heapDistances[2] =
    {&heapDistances1, &heapDistances2};

  DeviceTensor<int, 3, true> heapIndices1(
    mem, {queryTileSize, pass2Chunks, k}, stream);
  DeviceTensor<int, 3, true> heapIndices2(
    mem, {queryTileSize, pass2Chunks, k}, stream);
  DeviceTensor<int, 3, true>* heapIndices[2] =
    {&heapIndices1, &heapIndices2};

  auto streams = res->getAlternateStreamsCurrentDevice();
  streamWait(streams, {stream});

  int curStream = 0;

  for (int query = 0; query < queries.getSize(0); query += queryTileSize) {
    int numQueriesInTile =
      std::min(queryTileSize, queries.getSize(0) - query);

    auto prefixSumOffsetsView =
      prefixSumOffsets[curStream]->narrowOutermost(0, numQueriesInTile);

    auto coarseIndicesView =
      topQueryToCentroid.narrowOutermost(query, numQueriesInTile);
    auto queryView =
      queries.narrowOutermost(query, numQueriesInTile);
    auto term1View =
      precompTerm1.narrowOutermost(query, numQueriesInTile);
    auto term3View =
      precompTerm3.narrowOutermost(query, numQueriesInTile);

    auto heapDistancesView =
      heapDistances[curStream]->narrowOutermost(0, numQueriesInTile);
    auto heapIndicesView =
      heapIndices[curStream]->narrowOutermost(0, numQueriesInTile);

    auto outDistanceView =
      outDistances.narrowOutermost(query, numQueriesInTile);
    auto outIndicesView =
      outIndices.narrowOutermost(query, numQueriesInTile);
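    // Run this query tile on the current alternate stream; all of the
    // double-buffered temporaries indexed by curStream belong to that stream.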
    runMultiPassTile(queryView,
                     term1View,
                     precompTerm2,
                     term3View,
                     coarseIndicesView,
                     bitset,
                     useFloat16Lookup,
                     bytesPerCode,
                     numSubQuantizers,
                     numSubQuantizerCodes,
                     listCodes,
                     listIndices,
                     indicesOptions,
                     listLengths,
                     *thrustMem[curStream],
                     prefixSumOffsetsView,
                     *allDistances[curStream],
                     heapDistancesView,
                     heapIndicesView,
                     k,
                     outDistanceView,
                     outIndicesView,
                     streams[curStream]);

    curStream = (curStream + 1) % 2;
  }

  streamWait({stream}, streams);
}

} } // namespace