/**
 * Copyright (c) Facebook, Inc. and its affiliates.
 *
 * This source code is licensed under the MIT license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <faiss/gpu/impl/PQCodeDistances.cuh>

#include <faiss/gpu/impl/BroadcastSum.cuh>
#include <faiss/gpu/impl/Distance.cuh>
#include <faiss/gpu/impl/L2Norm.cuh>
#include <faiss/gpu/utils/ConversionOperators.cuh>
#include <faiss/gpu/utils/DeviceDefs.cuh>
#include <faiss/gpu/utils/DeviceUtils.h>
#include <faiss/gpu/utils/DeviceTensor.cuh>
#include <faiss/gpu/utils/Float16.cuh>
#include <faiss/gpu/utils/MatrixMult.cuh>
#include <faiss/gpu/utils/PtxUtils.cuh>
#include <faiss/gpu/utils/StaticUtils.h>
#include <faiss/gpu/utils/Transpose.cuh>

namespace faiss { namespace gpu {

template <typename T>
struct Converter {
};

#ifdef FAISS_USE_FLOAT16
template <>
struct Converter<half> {
  inline static __device__ half to(float v) { return __float2half(v); }
};
#endif

template <>
struct Converter<float> {
  inline static __device__ float to(float v) { return v; }
};

// Kernel responsible for calculating distance from residual vector to
// each product quantizer code centroid
template <typename OutCodeT, int DimsPerSubQuantizer, bool L2Distance>
__global__ void
__launch_bounds__(288, 4)
pqCodeDistances(Tensor<float, 2, true> queries,
                int queriesPerBlock,
                Tensor<float, 2, true> coarseCentroids,
                Tensor<float, 3, true> pqCentroids,
                Tensor<int, 2, true> topQueryToCentroid,
                // (query id)(coarse)(subquantizer)(code) -> dist
                Tensor<OutCodeT, 4, true> outCodeDistances) {
  const auto numSubQuantizers = pqCentroids.getSize(0);
  const auto dimsPerSubQuantizer = pqCentroids.getSize(1);
  assert(DimsPerSubQuantizer == dimsPerSubQuantizer);
  const auto codesPerSubQuantizer = pqCentroids.getSize(2);

  bool isLoadingThread = threadIdx.x >= codesPerSubQuantizer;
  int loadingThreadId = threadIdx.x - codesPerSubQuantizer;

  extern __shared__ float smem[];

  // Each thread calculates a single code
  float subQuantizerData[DimsPerSubQuantizer];

  auto code = threadIdx.x;
  auto subQuantizer = blockIdx.y;

  // Each thread will load the pq centroid data for the code that it
  // is processing
#pragma unroll
  for (int i = 0; i < DimsPerSubQuantizer; ++i) {
    subQuantizerData[i] = pqCentroids[subQuantizer][i][code].ldg();
  }

  // Where we store our query vector
  float* smemQuery = smem;

  // Where we store our residual vector; this is double buffered so we
  // can be loading the next one while processing the current one
  float* smemResidual1 = &smemQuery[DimsPerSubQuantizer];
  float* smemResidual2 = &smemResidual1[DimsPerSubQuantizer];

  // Where we pre-load the coarse centroid IDs
  int* coarseIds = (int*) &smemResidual2[DimsPerSubQuantizer];

  // Each thread is calculating the distance for a single code,
  // performing the reductions locally

  // Handle multiple queries per block
  auto startQueryId = blockIdx.x * queriesPerBlock;
  auto numQueries = queries.getSize(0) - startQueryId;
  if (numQueries > queriesPerBlock) {
    numQueries = queriesPerBlock;
  }

  for (int query = 0; query < numQueries; ++query) {
    auto queryId = startQueryId + query;

    auto querySubQuantizer =
      queries[queryId][subQuantizer * DimsPerSubQuantizer].data();

    // Load current query vector
    for (int i = threadIdx.x; i < DimsPerSubQuantizer; i += blockDim.x) {
      smemQuery[i] = querySubQuantizer[i];
    }

    // Load list of coarse centroids found
    for (int i = threadIdx.x;
         i < topQueryToCentroid.getSize(1); i += blockDim.x) {
      coarseIds[i] = topQueryToCentroid[queryId][i];
    }

    // We need coarseIds below
    // FIXME: investigate loading separately, so we don't need this
    __syncthreads();
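
    // Double-buffering protocol: the first codesPerSubQuantizer threads
    // compute distances from smemResidual1 for list entry `coarse`, while
    // the remaining loading threads fill smemResidual2 with the residual
    // for entry `coarse + 1`. The buffers are swapped at the end of each
    // iteration, with __syncthreads() as the barrier between the roles.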
    // Preload first buffer of residual data
    if (isLoadingThread) {
      for (int i = loadingThreadId;
           i < DimsPerSubQuantizer;
           i += blockDim.x - codesPerSubQuantizer) {
        auto coarseId = coarseIds[0];
        // In case NaNs were in the original query data
        coarseId = coarseId == -1 ? 0 : coarseId;
        auto coarseCentroidSubQuantizer =
          coarseCentroids[coarseId]
            [subQuantizer * dimsPerSubQuantizer].data();

        if (L2Distance) {
          smemResidual1[i] = smemQuery[i] - coarseCentroidSubQuantizer[i];
        } else {
          smemResidual1[i] = coarseCentroidSubQuantizer[i];
        }
      }
    }

    // The block walks the list for a single query
    for (int coarse = 0; coarse < topQueryToCentroid.getSize(1); ++coarse) {
      // Wait for smemResidual1 to be loaded
      __syncthreads();

      if (isLoadingThread) {
        // Preload second buffer of residual data
        for (int i = loadingThreadId;
             i < DimsPerSubQuantizer;
             i += blockDim.x - codesPerSubQuantizer) {
          // FIXME: try always making this centroid id 0 so we can
          // terminate
          if (coarse != (topQueryToCentroid.getSize(1) - 1)) {
            auto coarseId = coarseIds[coarse + 1];
            // In case NaNs were in the original query data
            coarseId = coarseId == -1 ? 0 : coarseId;
            auto coarseCentroidSubQuantizer =
              coarseCentroids[coarseId]
                [subQuantizer * dimsPerSubQuantizer].data();

            if (L2Distance) {
              smemResidual2[i] = smemQuery[i] - coarseCentroidSubQuantizer[i];
            } else {
              smemResidual2[i] = coarseCentroidSubQuantizer[i];
            }
          }
        }
      } else {
        // These are the processing threads
        float dist = 0.0f;

        constexpr int kUnroll = 4;
        constexpr int kRemainder = DimsPerSubQuantizer % kUnroll;
        constexpr int kRemainderBase = DimsPerSubQuantizer - kRemainder;
        float vals[kUnroll];

        // Calculate residual - pqCentroid for each dim that we're
        // processing

        // Unrolled loop
        if (L2Distance) {
#pragma unroll
          for (int i = 0; i < DimsPerSubQuantizer / kUnroll; ++i) {
#pragma unroll
            for (int j = 0; j < kUnroll; ++j) {
              vals[j] = smemResidual1[i * kUnroll + j];
            }

#pragma unroll
            for (int j = 0; j < kUnroll; ++j) {
              vals[j] -= subQuantizerData[i * kUnroll + j];
            }

#pragma unroll
            for (int j = 0; j < kUnroll; ++j) {
              vals[j] *= vals[j];
            }

#pragma unroll
            for (int j = 0; j < kUnroll; ++j) {
              dist += vals[j];
            }
          }
        } else {
          // Inner product: query slice against the reconstructed
          // sub-quantizer for this coarse cell
          // (query o (centroid + subQCentroid))
#pragma unroll
          for (int i = 0; i < DimsPerSubQuantizer / kUnroll; ++i) {
#pragma unroll
            for (int j = 0; j < kUnroll; ++j) {
              vals[j] = smemResidual1[i * kUnroll + j];
            }

#pragma unroll
            for (int j = 0; j < kUnroll; ++j) {
              vals[j] += subQuantizerData[i * kUnroll + j];
            }

#pragma unroll
            for (int j = 0; j < kUnroll; ++j) {
              vals[j] *= smemQuery[i * kUnroll + j];
            }

#pragma unroll
            for (int j = 0; j < kUnroll; ++j) {
              dist += vals[j];
            }
          }
        }

        // Remainder loop
        if (L2Distance) {
#pragma unroll
          for (int j = 0; j < kRemainder; ++j) {
            vals[j] = smemResidual1[kRemainderBase + j];
          }

#pragma unroll
          for (int j = 0; j < kRemainder; ++j) {
            vals[j] -= subQuantizerData[kRemainderBase + j];
          }

#pragma unroll
          for (int j = 0; j < kRemainder; ++j) {
            vals[j] *= vals[j];
          }
        } else {
          // Inner product: query slice against the reconstructed
          // sub-quantizer for this coarse cell
          // (query o (centroid + subQCentroid))
#pragma unroll
          for (int j = 0; j < kRemainder; ++j) {
            vals[j] = smemResidual1[kRemainderBase + j];
          }

#pragma unroll
          for (int j = 0; j < kRemainder; ++j) {
            vals[j] += subQuantizerData[kRemainderBase + j];
          }

#pragma unroll
          for (int j = 0; j < kRemainder; ++j) {
            vals[j] *= smemQuery[kRemainderBase + j];
          }
        }

#pragma unroll
        for (int j = 0; j < kRemainder; ++j) {
          dist += vals[j];
        }

        // We have the distance for our code; write it out
        outCodeDistances[queryId][coarse][subQuantizer][code] =
          Converter<OutCodeT>::to(dist);
      } // !isLoadingThread

      // Swap residual buffers
      float* tmp = smemResidual1;
      smemResidual1 = smemResidual2;
      smemResidual2 = tmp;
    }
  }
}
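
// Summary of what the kernel above produces, per
// (query, coarse cell, sub-quantizer, code):
//   L2:            dist = || (q_sub - c_sub) - pq_code ||^2
//   inner product: dist = q_sub . (c_sub + pq_code)
// where q_sub is the query slice for this sub-quantizer, c_sub the matching
// slice of the coarse centroid, and pq_code the PQ code centroid. For L2 the
// residual (q_sub - c_sub) is staged in shared memory; for inner product the
// coarse centroid slice itself is staged, and the query multiplies the
// reconstructed (centroid + code) vector.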
__global__ void
residualVector(Tensor<float, 2, true> queries,
               Tensor<float, 2, true> coarseCentroids,
               Tensor<int, 2, true> topQueryToCentroid,
               int numSubDim,
               // output is transposed:
               // (sub q)(query id)(centroid id)(sub dim)
               Tensor<float, 4, true> residual) {
  // block x is query id
  // block y is centroid id
  // thread x is dim
  auto queryId = blockIdx.x;
  auto centroidId = blockIdx.y;

  int realCentroidId = topQueryToCentroid[queryId][centroidId];

  for (int dim = threadIdx.x; dim < queries.getSize(1); dim += blockDim.x) {
    float q = queries[queryId][dim];
    float c = coarseCentroids[realCentroidId][dim];

    residual[dim / numSubDim][queryId][centroidId][dim % numSubDim] =
      q - c;
  }
}

void
runResidualVector(Tensor<float, 3, true>& pqCentroids,
                  Tensor<float, 2, true>& queries,
                  Tensor<float, 2, true>& coarseCentroids,
                  Tensor<int, 2, true>& topQueryToCentroid,
                  Tensor<float, 4, true>& residual,
                  cudaStream_t stream) {
  auto grid =
    dim3(topQueryToCentroid.getSize(0), topQueryToCentroid.getSize(1));
  auto block =
    dim3(std::min(queries.getSize(1), getMaxThreadsCurrentDevice()));

  residualVector<<<grid, block, 0, stream>>>(
    queries, coarseCentroids, topQueryToCentroid, pqCentroids.getSize(1),
    residual);

  CUDA_TEST_ERROR();
}

void
runPQCodeDistancesMM(Tensor<float, 3, true>& pqCentroids,
                     Tensor<float, 2, true>& queries,
                     Tensor<float, 2, true>& coarseCentroids,
                     Tensor<int, 2, true>& topQueryToCentroid,
                     NoTypeTensor<4, true>& outCodeDistances,
                     bool useFloat16Lookup,
                     DeviceMemory& mem,
                     cublasHandle_t handle,
                     cudaStream_t stream) {
  // Calculate (q - c) residual vector
  // (sub q)(query id)(centroid id)(sub dim)
  DeviceTensor<float, 4, true> residual(
    mem,
    {pqCentroids.getSize(0),
     topQueryToCentroid.getSize(0),
     topQueryToCentroid.getSize(1),
     pqCentroids.getSize(1)},
    stream);

  runResidualVector(pqCentroids, queries,
                    coarseCentroids, topQueryToCentroid,
                    residual, stream);

  // Calculate ||q - c||^2
  DeviceTensor<float, 1, true> residualNorms(
    mem,
    {pqCentroids.getSize(0) *
     topQueryToCentroid.getSize(0) *
     topQueryToCentroid.getSize(1)},
    stream);

  auto residualView2 = residual.view<2>(
    {pqCentroids.getSize(0) *
     topQueryToCentroid.getSize(0) *
     topQueryToCentroid.getSize(1),
     pqCentroids.getSize(1)});

  runL2Norm(residualView2, true, residualNorms, true, stream);

  // Perform a batch MM:
  // (sub q) x {(q * c)(sub dim) x (sub dim)(code)} =>
  // (sub q) x {(q * c)(code)}
  auto residualView3 = residual.view<3>(
    {pqCentroids.getSize(0),
     topQueryToCentroid.getSize(0) * topQueryToCentroid.getSize(1),
     pqCentroids.getSize(1)});

  DeviceTensor<float, 3, true> residualDistance(
    mem,
    {pqCentroids.getSize(0),
     topQueryToCentroid.getSize(0) * topQueryToCentroid.getSize(1),
     pqCentroids.getSize(2)},
    stream);

  runIteratedMatrixMult(residualDistance, false,
                        residualView3, false,
                        pqCentroids, false,
                        -2.0f, 0.0f,
                        handle,
                        stream);

  // Sum ||q - c||^2 along rows
  auto residualDistanceView2 = residualDistance.view<2>(
    {pqCentroids.getSize(0) *
     topQueryToCentroid.getSize(0) *
     topQueryToCentroid.getSize(1),
     pqCentroids.getSize(2)});

  runSumAlongRows(residualNorms, residualDistanceView2, false, stream);
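
  // At this point residualDistance holds, per (sub q, query x coarse, code):
  //   ||q - c||^2 - 2 (q - c) . r
  // where r is the PQ code centroid. Expanding
  //   || (q - c) - r ||^2 = ||q - c||^2 - 2 (q - c) . r + ||r||^2,
  // all that remains is to transpose into the output layout and add the
  // per-code norms ||r||^2, which the steps below perform.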
  Tensor<float, 4, true> outCodeDistancesF;
  DeviceTensor<float, 4, true> outCodeDistancesFloatMem;

#ifdef FAISS_USE_FLOAT16
  if (useFloat16Lookup) {
    outCodeDistancesFloatMem = DeviceTensor<float, 4, true>(
      mem,
      {outCodeDistances.getSize(0),
       outCodeDistances.getSize(1),
       outCodeDistances.getSize(2),
       outCodeDistances.getSize(3)},
      stream);

    outCodeDistancesF = outCodeDistancesFloatMem;
  } else {
    outCodeDistancesF = outCodeDistances.toTensor<float>();
  }
#else
  outCodeDistancesF = outCodeDistances.toTensor<float>();
#endif

  // Transpose -2(sub q)(q * c)(code) to -2(q * c)(sub q)(code) (which
  // is where we build our output distances)
  auto outCodeDistancesView = outCodeDistancesF.view<3>(
    {topQueryToCentroid.getSize(0) * topQueryToCentroid.getSize(1),
     outCodeDistances.getSize(2),
     outCodeDistances.getSize(3)});

  runTransposeAny(residualDistance, 0, 1, outCodeDistancesView, stream);

  // Calculate code norms per each sub-dim
  // (sub q)(sub dim)(code) is pqCentroids
  // transpose to (sub q)(code)(sub dim)
  DeviceTensor<float, 3, true> pqCentroidsTranspose(
    mem,
    {pqCentroids.getSize(0), pqCentroids.getSize(2), pqCentroids.getSize(1)},
    stream);

  runTransposeAny(pqCentroids, 1, 2, pqCentroidsTranspose, stream);

  auto pqCentroidsTransposeView = pqCentroidsTranspose.view<2>(
    {pqCentroids.getSize(0) * pqCentroids.getSize(2),
     pqCentroids.getSize(1)});

  DeviceTensor<float, 1, true> pqCentroidsNorm(
    mem,
    {pqCentroids.getSize(0) * pqCentroids.getSize(2)},
    stream);

  runL2Norm(pqCentroidsTransposeView, true, pqCentroidsNorm, true, stream);

  // View output as (q * c)(sub q * code), and add centroid norm to
  // each row
  auto outDistancesCodeViewCols = outCodeDistancesView.view<2>(
    {topQueryToCentroid.getSize(0) * topQueryToCentroid.getSize(1),
     outCodeDistances.getSize(2) * outCodeDistances.getSize(3)});

  runSumAlongColumns(pqCentroidsNorm, outDistancesCodeViewCols, stream);

#ifdef FAISS_USE_FLOAT16
  if (useFloat16Lookup) {
    // Need to convert back
    auto outCodeDistancesH = outCodeDistances.toTensor<half>();
    convertTensor<float, half, 4>(stream,
                                  outCodeDistancesF,
                                  outCodeDistancesH);
  }
#endif
}

void
runPQCodeDistances(Tensor<float, 3, true>& pqCentroids,
                   Tensor<float, 2, true>& queries,
                   Tensor<float, 2, true>& coarseCentroids,
                   Tensor<int, 2, true>& topQueryToCentroid,
                   NoTypeTensor<4, true>& outCodeDistances,
                   bool l2Distance,
                   bool useFloat16Lookup,
                   cudaStream_t stream) {
  const auto numSubQuantizers = pqCentroids.getSize(0);
  const auto dimsPerSubQuantizer = pqCentroids.getSize(1);
  const auto codesPerSubQuantizer = pqCentroids.getSize(2);

  // FIXME: tune
  // Reuse of pq centroid data depends on both # of queries and nprobe;
  // we should really be tiling in both dimensions
  constexpr int kQueriesPerBlock = 8;

  auto grid = dim3(utils::divUp(queries.getSize(0), kQueriesPerBlock),
                   numSubQuantizers);
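
  // Illustrative sizing (example numbers, not computed here): with 256 codes
  // and 8 dims per sub-quantizer, the block below becomes
  // 256 + roundUp(8, 32) = 288 threads, matching __launch_bounds__(288, 4)
  // on the kernel, and smem is 3 * 8 * sizeof(float) bytes plus nprobe ints
  // for the coarse centroid IDs.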
  // Reserve one block of threads for double buffering
  // FIXME: probably impractical for large # of dims?
  auto loadingThreads = utils::roundUp(dimsPerSubQuantizer, kWarpSize);
  auto block = dim3(codesPerSubQuantizer + loadingThreads);

  auto smem = (3 * dimsPerSubQuantizer) * sizeof(float)
    + topQueryToCentroid.getSize(1) * sizeof(int);

#ifdef FAISS_USE_FLOAT16
#define RUN_CODE(DIMS, L2)                                              \
  do {                                                                  \
    if (useFloat16Lookup) {                                             \
      auto outCodeDistancesT = outCodeDistances.toTensor<half>();       \
                                                                        \
      pqCodeDistances<half, DIMS, L2><<<grid, block, smem, stream>>>(   \
        queries, kQueriesPerBlock,                                      \
        coarseCentroids, pqCentroids,                                   \
        topQueryToCentroid, outCodeDistancesT);                         \
    } else {                                                            \
      auto outCodeDistancesT = outCodeDistances.toTensor<float>();      \
                                                                        \
      pqCodeDistances<float, DIMS, L2><<<grid, block, smem, stream>>>(  \
        queries, kQueriesPerBlock,                                      \
        coarseCentroids, pqCentroids,                                   \
        topQueryToCentroid, outCodeDistancesT);                         \
    }                                                                   \
  } while (0)
#else
#define RUN_CODE(DIMS, L2)                                              \
  do {                                                                  \
    if (!useFloat16Lookup) {                                            \
      auto outCodeDistancesT = outCodeDistances.toTensor<float>();      \
                                                                        \
      pqCodeDistances<float, DIMS, L2><<<grid, block, smem, stream>>>(  \
        queries, kQueriesPerBlock,                                      \
        coarseCentroids, pqCentroids,                                   \
        topQueryToCentroid, outCodeDistancesT);                         \
    }                                                                   \
  } while (0)
#endif

#define CODE_L2(DIMS)                           \
  do {                                          \
    if (l2Distance) {                           \
      RUN_CODE(DIMS, true);                     \
    } else {                                    \
      RUN_CODE(DIMS, false);                    \
    }                                           \
  } while (0)

  switch (dimsPerSubQuantizer) {
    case 1:
      CODE_L2(1);
      break;
    case 2:
      CODE_L2(2);
      break;
    case 3:
      CODE_L2(3);
      break;
    case 4:
      CODE_L2(4);
      break;
    case 6:
      CODE_L2(6);
      break;
    case 8:
      CODE_L2(8);
      break;
    case 10:
      CODE_L2(10);
      break;
    case 12:
      CODE_L2(12);
      break;
    case 16:
      CODE_L2(16);
      break;
    case 20:
      CODE_L2(20);
      break;
    case 24:
      CODE_L2(24);
      break;
    case 28:
      CODE_L2(28);
      break;
    case 32:
      CODE_L2(32);
      break;
    // FIXME: larger sizes require too many registers - we need the
    // MM implementation working
    default:
      FAISS_THROW_MSG("Too many dimensions (>32) per subquantizer "
                      "not currently supported");
  }

#undef RUN_CODE
#undef CODE_L2

  CUDA_TEST_ERROR();
}

} } // namespace
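
// Illustrative call of the specialized-kernel path (a sketch with
// hypothetical sizes; in Faiss this is driven by the IVFPQ index
// implementation). For 64 sub-quantizers of 8 dims with 256 codes each,
// nq queries of dim 512, and nprobe probed lists per query:
//
//   // pqCentroids:        (64, 8, 256)           float
//   // queries:            (nq, 512)              float
//   // coarseCentroids:    (nlist, 512)           float
//   // topQueryToCentroid: (nq, nprobe)           int
//   // outCodeDistances:   (nq, nprobe, 64, 256)  float or half
//   runPQCodeDistances(pqCentroids, queries, coarseCentroids,
//                      topQueryToCentroid, outCodeDistances,
//                      /*l2Distance=*/true, /*useFloat16Lookup=*/false,
//                      stream);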