diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt
index 5fc2f28c13cc3fdab8f555e34ac0f48f81b6d333..51b07b39ef3811141a7cf0383b85cd357daf60aa 100644
--- a/cpp/CMakeLists.txt
+++ b/cpp/CMakeLists.txt
@@ -6,7 +6,13 @@
 cmake_minimum_required(VERSION 3.12)
 
-project(vecwise_engine)
+project(vecwise_engine LANGUAGES CUDA CXX)
+
+find_package(CUDA)
+set(CUDA_NVCC_FLAGS "${CUDA_NVCC_FLAGS} -Xcompiler -fPIC -std=c++11 -D_FORCE_INLINES -arch sm_60 --expt-extended-lambda")
+set(CUDA_NVCC_FLAGS "${CUDA_NVCC_FLAGS} -O0 -g")
+message("CUDA_TOOLKIT_ROOT_DIR=${CUDA_TOOLKIT_ROOT_DIR}")
+message("CUDA_NVCC_FLAGS=${CUDA_NVCC_FLAGS}")
 
 set(CMAKE_CXX_STANDARD 14)
diff --git a/cpp/src/CMakeLists.txt b/cpp/src/CMakeLists.txt
index 62acba2f2e9ce60a31feced0ec2aeb0449e3189a..835a0d892023e1436d2905d683004868f8120435 100644
--- a/cpp/src/CMakeLists.txt
+++ b/cpp/src/CMakeLists.txt
@@ -27,7 +27,7 @@
 find_library(cuda_library cudart cublas HINTS /usr/local/cuda/lib64)
 
 add_library(vecwise_engine STATIC ${vecwise_engine_src})
-add_executable(vecwise_server
+cuda_add_executable(vecwise_server
         ${config_files}
         ${server_files}
         ${utils_files}
diff --git a/cpp/src/wrapper/Arithmetic.h b/cpp/src/wrapper/Arithmetic.h
new file mode 100644
index 0000000000000000000000000000000000000000..6a49e8d33423adcc43f68013cdc980279310d4da
--- /dev/null
+++ b/cpp/src/wrapper/Arithmetic.h
@@ -0,0 +1,67 @@
+/*******************************************************************************
+ * Copyright 上海赜睿信息科技有限公司(Zilliz) - All Rights Reserved
+ * Unauthorized copying of this file, via any medium is strictly prohibited.
+ * Proprietary and confidential.
+ ******************************************************************************/
+#pragma once
+
+#include <cstdint>
+#include <cstddef>
+#include <cstring>
+#include <limits>
+
+
+namespace zilliz {
+namespace vecwise {
+namespace engine {
+
+using Bool = int8_t;
+using Byte = uint8_t;
+using Word = unsigned long;
+using EnumType = uint64_t;
+
+using Float32 = float;
+using Float64 = double;
+
+constexpr bool kBoolMax = std::numeric_limits<bool>::max();
+constexpr bool kBoolMin = std::numeric_limits<bool>::lowest();
+
+constexpr int8_t kInt8Max = std::numeric_limits<int8_t>::max();
+constexpr int8_t kInt8Min = std::numeric_limits<int8_t>::lowest();
+
+constexpr int16_t kInt16Max = std::numeric_limits<int16_t>::max();
+constexpr int16_t kInt16Min = std::numeric_limits<int16_t>::lowest();
+
+constexpr int32_t kInt32Max = std::numeric_limits<int32_t>::max();
+constexpr int32_t kInt32Min = std::numeric_limits<int32_t>::lowest();
+
+constexpr int64_t kInt64Max = std::numeric_limits<int64_t>::max();
+constexpr int64_t kInt64Min = std::numeric_limits<int64_t>::lowest();
+
+constexpr float kFloatMax = std::numeric_limits<float>::max();
+constexpr float kFloatMin = std::numeric_limits<float>::lowest();
+
+constexpr double kDoubleMax = std::numeric_limits<double>::max();
+constexpr double kDoubleMin = std::numeric_limits<double>::lowest();
+
+constexpr uint32_t kFloat32DecimalPrecision = std::numeric_limits<float>::digits10;
+constexpr uint32_t kFloat64DecimalPrecision = std::numeric_limits<double>::digits10;
+
+
+constexpr uint8_t kByteWidth = 8;
+constexpr uint8_t kCharWidth = kByteWidth;
+constexpr uint8_t kWordWidth = sizeof(Word) * kByteWidth;
+constexpr uint8_t kEnumTypeWidth = sizeof(EnumType) * kByteWidth;
+
+template<typename T>
+inline size_t
+WidthOf() { return sizeof(T) << 3; }
+
+template<typename T>
+inline size_t
+WidthOf(const T &) { return sizeof(T) << 3; }
+
+
+} // namespace engine
+} // namespace vecwise
+} // namespace zilliz
diff --git a/cpp/src/wrapper/Topk.cu b/cpp/src/wrapper/Topk.cu
new file mode 100644
index 0000000000000000000000000000000000000000..10978c8151f559808b2afcb389a90b426eeca65b
--- /dev/null
+++ b/cpp/src/wrapper/Topk.cu
@@ -0,0 +1,574 @@
+////////////////////////////////////////////////////////////////////////////////
+// Copyright 上海赜睿信息科技有限公司(Zilliz) - All Rights Reserved
+// Unauthorized copying of this file, via any medium is strictly prohibited.
+// Proprietary and confidential.
+////////////////////////////////////////////////////////////////////////////////
+
+#include "faiss/FaissAssert.h"
+#include "faiss/gpu/utils/Limits.cuh"
+#include "Arithmetic.h"
+
+
+namespace faiss {
+namespace gpu {
+
+constexpr bool kBoolMax = zilliz::vecwise::engine::kBoolMax;
+constexpr bool kBoolMin = zilliz::vecwise::engine::kBoolMin;
+
+template<>
+struct Limits<bool> {
+    static __device__ __host__
+    inline bool getMin() {
+        return kBoolMin;
+    }
+    static __device__ __host__
+    inline bool getMax() {
+        return kBoolMax;
+    }
+};
+
+constexpr int8_t kInt8Max = zilliz::vecwise::engine::kInt8Max;
+constexpr int8_t kInt8Min = zilliz::vecwise::engine::kInt8Min;
+
+template<>
+struct Limits<int8_t> {
+    static __device__ __host__
+    inline int8_t getMin() {
+        return kInt8Min;
+    }
+    static __device__ __host__
+    inline int8_t getMax() {
+        return kInt8Max;
+    }
+};
+
+constexpr int16_t kInt16Max = zilliz::vecwise::engine::kInt16Max;
+constexpr int16_t kInt16Min = zilliz::vecwise::engine::kInt16Min;
+
+template<>
+struct Limits<int16_t> {
+    static __device__ __host__
+    inline int16_t getMin() {
+        return kInt16Min;
+    }
+    static __device__ __host__
+    inline int16_t getMax() {
+        return kInt16Max;
+    }
+};
+
+constexpr int64_t kInt64Max = zilliz::vecwise::engine::kInt64Max;
+constexpr int64_t kInt64Min = zilliz::vecwise::engine::kInt64Min;
+
+template<>
+struct Limits<int64_t> {
+    static __device__ __host__
+    inline int64_t getMin() {
+        return kInt64Min;
+    }
+    static __device__ __host__
+    inline int64_t getMax() {
+        return kInt64Max;
+    }
+};
+
+constexpr double kDoubleMax = zilliz::vecwise::engine::kDoubleMax;
+constexpr double kDoubleMin = zilliz::vecwise::engine::kDoubleMin;
+
+template<>
+struct Limits<double> {
+    static __device__ __host__
+    inline double getMin() {
+        return kDoubleMin;
+    }
+    static __device__ __host__
+    inline double getMax() {
+        return kDoubleMax;
+    }
+};
+
+}
+}
+
+#include "faiss/gpu/utils/DeviceUtils.h"
+#include "faiss/gpu/utils/MathOperators.cuh"
+#include "faiss/gpu/utils/Pair.cuh"
+#include "faiss/gpu/utils/Reductions.cuh"
+#include "faiss/gpu/utils/Select.cuh"
+#include "faiss/gpu/utils/Tensor.cuh"
+#include "faiss/gpu/utils/StaticUtils.h"
+
+#include "Topk.h"
+
+
+namespace zilliz {
+namespace vecwise {
+namespace engine {
+namespace gpu {
+
+constexpr int kWarpSize = 32;
+
+template<typename T>
+using Tensor = faiss::gpu::Tensor<T, 2, true>;
+
+template<typename K, typename V>
+using Pair = faiss::gpu::Pair<K, V>;
+
+
+// select kernel for k == 1
+template<typename T, int kRowsPerBlock, int kBlockSize>
+__global__ void topkSelectMin1(Tensor<T> productDistances,
+                               Tensor<T> outDistances,
+                               Tensor<int64_t> outIndices) {
+    // Each block handles kRowsPerBlock rows of the distances (results)
+    Pair<T, int64_t> threadMin[kRowsPerBlock];
+    __shared__
+    Pair<T, int64_t> blockMin[kRowsPerBlock * (kBlockSize / kWarpSize)];
+
+    T distance[kRowsPerBlock];
+
+#pragma unroll
+    for (int i = 0; i < kRowsPerBlock; ++i) {
+        threadMin[i].k = faiss::gpu::Limits<T>::getMax();
+        threadMin[i].v = -1;
+    }
+
+    // blockIdx.x: which chunk of rows we are responsible for updating
+    int rowStart = blockIdx.x * kRowsPerBlock;
+
+    // FIXME: if we have exact multiples, don't need this
+    bool endRow = (blockIdx.x == gridDim.x - 1);
+
+    if (endRow) {
+        if (productDistances.getSize(0) % kRowsPerBlock == 0) {
+            endRow = false;
+        }
+    }
+
+    if (endRow) {
+        for (int row = rowStart; row < productDistances.getSize(0); ++row) {
+            for (int col = threadIdx.x; col < productDistances.getSize(1);
+                 col += blockDim.x) {
+                distance[0] = productDistances[row][col];
+
+                if (faiss::gpu::Math<T>::lt(distance[0], threadMin[0].k)) {
+                    threadMin[0].k = distance[0];
+                    threadMin[0].v = col;
+                }
+            }
+
+            // Reduce within the block
+            threadMin[0] =
+                faiss::gpu::blockReduceAll<Pair<T, int64_t>, faiss::gpu::Min<Pair<T, int64_t> >, false, false>(
+                    threadMin[0], faiss::gpu::Min<Pair<T, int64_t> >(), blockMin);
+
+            if (threadIdx.x == 0) {
+                outDistances[row][0] = threadMin[0].k;
+                outIndices[row][0] = threadMin[0].v;
+            }
+
+            // so we can use the shared memory again
+            __syncthreads();
+
+            threadMin[0].k = faiss::gpu::Limits<T>::getMax();
+            threadMin[0].v = -1;
+        }
+    } else {
+        for (int col = threadIdx.x; col < productDistances.getSize(1);
+             col += blockDim.x) {
+
+#pragma unroll
+            for (int row = 0; row < kRowsPerBlock; ++row) {
+                distance[row] = productDistances[rowStart + row][col];
+            }
+
+#pragma unroll
+            for (int row = 0; row < kRowsPerBlock; ++row) {
+                if (faiss::gpu::Math<T>::lt(distance[row], threadMin[row].k)) {
+                    threadMin[row].k = distance[row];
+                    threadMin[row].v = col;
+                }
+            }
+        }
+
+        // Reduce within the block
+        faiss::gpu::blockReduceAll<kRowsPerBlock, Pair<T, int64_t>, faiss::gpu::Min<Pair<T, int64_t> >, false, false>(
+            threadMin, faiss::gpu::Min<Pair<T, int64_t> >(), blockMin);
+
+        if (threadIdx.x == 0) {
+#pragma unroll
+            for (int row = 0; row < kRowsPerBlock; ++row) {
+                outDistances[rowStart + row][0] = threadMin[row].k;
+                outIndices[rowStart + row][0] = threadMin[row].v;
+            }
+        }
+    }
+}
+
+// select kernel for k > 1
+template<typename T, int NumWarpQ, int NumThreadQ, int ThreadsPerBlock>
+__global__ void topkSelectMinK(Tensor<T> productDistances,
+                               Tensor<T> outDistances,
+                               Tensor<int64_t> outIndices,
+                               int k, T initK) {
+    // Each block handles a single row of the distances (results)
+    constexpr int kNumWarps = ThreadsPerBlock / kWarpSize;
+
+    __shared__
+    T smemK[kNumWarps * NumWarpQ];
+    __shared__
+    int64_t smemV[kNumWarps * NumWarpQ];
+
+    faiss::gpu::BlockSelect<T, int64_t, false, faiss::gpu::Comparator<T>,
+                            NumWarpQ, NumThreadQ, ThreadsPerBlock>
+        heap(initK, -1, smemK, smemV, k);
+
+    int row = blockIdx.x;
+
+    // Whole warps must participate in the selection
+    int limit = faiss::gpu::utils::roundDown(productDistances.getSize(1), kWarpSize);
+    int i = threadIdx.x;
+
+    for (; i < limit; i += blockDim.x) {
+        T v = productDistances[row][i];
+        heap.add(v, i);
+    }
+
+    if (i < productDistances.getSize(1)) {
+        T v = productDistances[row][i];
+        heap.addThreadQ(v, i);
+    }
+
+    heap.reduce();
+    for (int i = threadIdx.x; i < k; i += blockDim.x) {
+        outDistances[row][i] = smemK[i];
+        outIndices[row][i] = smemV[i];
+    }
+}
+
+// FIXME: no TVec specialization
+template<typename T>
+void runTopKSelectMin(Tensor<T> &productDistances,
+                      Tensor<T> &outDistances,
+                      Tensor<int64_t> &outIndices,
+                      int k,
+                      cudaStream_t stream) {
+    FAISS_ASSERT(productDistances.getSize(0) == outDistances.getSize(0));
+    FAISS_ASSERT(productDistances.getSize(0) == outIndices.getSize(0));
+    FAISS_ASSERT(outDistances.getSize(1) == k);
+    FAISS_ASSERT(outIndices.getSize(1) == k);
+    FAISS_ASSERT(k <= 1024);
+
+    if (k == 1) {
+        constexpr int kThreadsPerBlock = 256;
+        constexpr int kRowsPerBlock = 8;
+
+        auto block = dim3(kThreadsPerBlock);
+        auto grid = dim3(faiss::gpu::utils::divUp(outDistances.getSize(0), kRowsPerBlock));
+
+        topkSelectMin1<T, kRowsPerBlock, kThreadsPerBlock>
+            <<<grid, block, 0, stream>>>(productDistances, outDistances, outIndices);
+    } else {
+        constexpr int kThreadsPerBlock = 128;
+
+        auto block = dim3(kThreadsPerBlock);
+        auto grid = dim3(outDistances.getSize(0));
+
+#define RUN_TOPK_SELECT_MIN(NUM_WARP_Q, NUM_THREAD_Q)                           \
+    do {                                                                        \
+        topkSelectMinK<T, NUM_WARP_Q, NUM_THREAD_Q, kThreadsPerBlock>           \
+            <<<grid, block, 0, stream>>>(productDistances,                      \
+                                         outDistances, outIndices,              \
+                                         k, faiss::gpu::Limits<T>::getMax());   \
+    } while (0)
+
+        if (k <= 32) {
+            RUN_TOPK_SELECT_MIN(32, 2);
+        } else if (k <= 64) {
+            RUN_TOPK_SELECT_MIN(64, 3);
+        } else if (k <= 128) {
+            RUN_TOPK_SELECT_MIN(128, 3);
+        } else if (k <= 256) {
+            RUN_TOPK_SELECT_MIN(256, 4);
+        } else if (k <= 512) {
+            RUN_TOPK_SELECT_MIN(512, 8);
+        } else if (k <= 1024) {
+            RUN_TOPK_SELECT_MIN(1024, 8);
+        } else {
+            FAISS_ASSERT(false);
+        }
+    }
+
+    CUDA_TEST_ERROR();
+}
+
+////////////////////////////////////////////////////////////
+// select kernel for k == 1
+template<typename T, int kRowsPerBlock, int kBlockSize>
+__global__ void topkSelectMax1(Tensor<T> productDistances,
+                               Tensor<T> outDistances,
+                               Tensor<int64_t> outIndices) {
+    // Each block handles kRowsPerBlock rows of the distances (results)
+    Pair<T, int64_t> threadMax[kRowsPerBlock];
+    __shared__
+    Pair<T, int64_t> blockMax[kRowsPerBlock * (kBlockSize / kWarpSize)];
+
+    T distance[kRowsPerBlock];
+
+#pragma unroll
+    for (int i = 0; i < kRowsPerBlock; ++i) {
+        threadMax[i].k = faiss::gpu::Limits<T>::getMin();
+        threadMax[i].v = -1;
+    }
+
+    // blockIdx.x: which chunk of rows we are responsible for updating
+    int rowStart = blockIdx.x * kRowsPerBlock;
+
+    // FIXME: if we have exact multiples, don't need this
+    bool endRow = (blockIdx.x == gridDim.x - 1);
+
+    if (endRow) {
+        if (productDistances.getSize(0) % kRowsPerBlock == 0) {
+            endRow = false;
+        }
+    }
+
+    if (endRow) {
+        for (int row = rowStart; row < productDistances.getSize(0); ++row) {
+            for (int col = threadIdx.x; col < productDistances.getSize(1);
+                 col += blockDim.x) {
+                distance[0] = productDistances[row][col];
+
+                if (faiss::gpu::Math<T>::gt(distance[0], threadMax[0].k)) {
+                    threadMax[0].k = distance[0];
+                    threadMax[0].v = col;
+                }
+            }
+
+            // Reduce within the block
+            threadMax[0] =
+                faiss::gpu::blockReduceAll<Pair<T, int64_t>, faiss::gpu::Max<Pair<T, int64_t> >, false, false>(
+                    threadMax[0], faiss::gpu::Max<Pair<T, int64_t> >(), blockMax);
+
+            if (threadIdx.x == 0) {
+                outDistances[row][0] = threadMax[0].k;
+                outIndices[row][0] = threadMax[0].v;
+            }
+
+            // so we can use the shared memory again
+            __syncthreads();
+
+            threadMax[0].k = faiss::gpu::Limits<T>::getMin();
+            threadMax[0].v = -1;
+        }
+    } else {
+        for (int col = threadIdx.x; col < productDistances.getSize(1);
+             col += blockDim.x) {
+
+#pragma unroll
+            for (int row = 0; row < kRowsPerBlock; ++row) {
+                distance[row] = productDistances[rowStart + row][col];
+            }
+
+#pragma unroll
+            for (int row = 0; row < kRowsPerBlock; ++row) {
+                if (faiss::gpu::Math<T>::gt(distance[row], threadMax[row].k)) {
+                    threadMax[row].k = distance[row];
+                    threadMax[row].v = col;
+                }
+            }
+        }
+
+        // Reduce within the block
+        faiss::gpu::blockReduceAll<kRowsPerBlock, Pair<T, int64_t>, faiss::gpu::Max<Pair<T, int64_t> >, false, false>(
+            threadMax, faiss::gpu::Max<Pair<T, int64_t> >(), blockMax);
+
+        if (threadIdx.x == 0) {
+#pragma unroll
+            for (int row = 0; row < kRowsPerBlock; ++row) {
+                outDistances[rowStart + row][0] = threadMax[row].k;
+                outIndices[rowStart + row][0] = threadMax[row].v;
+            }
+        }
+    }
+}
+
+// select kernel for k > 1
+template<typename T, int NumWarpQ, int NumThreadQ, int ThreadsPerBlock>
+__global__ void topkSelectMaxK(Tensor<T> productDistances,
+                               Tensor<T> outDistances,
+                               Tensor<int64_t> outIndices,
+                               int k, T initK) {
+    // Each block handles a single row of the distances (results)
+    constexpr int kNumWarps = ThreadsPerBlock / kWarpSize;
+
+    __shared__
+    T smemK[kNumWarps * NumWarpQ];
+    __shared__
+    int64_t smemV[kNumWarps * NumWarpQ];
+
+    faiss::gpu::BlockSelect<T, int64_t, true, faiss::gpu::Comparator<T>,
+                            NumWarpQ, NumThreadQ, ThreadsPerBlock>
+        heap(initK, -1, smemK, smemV, k);
+
+    int row = blockIdx.x;
+
+    // Whole warps must participate in the selection
+    int limit = faiss::gpu::utils::roundDown(productDistances.getSize(1), kWarpSize);
+    int i = threadIdx.x;
+
+    for (; i < limit; i += blockDim.x) {
+        T v = productDistances[row][i];
+        heap.add(v, i);
+    }
+
+    if (i < productDistances.getSize(1)) {
+        T v = productDistances[row][i];
+        heap.addThreadQ(v, i);
+    }
+
+    heap.reduce();
+    for (int i = threadIdx.x; i < k; i += blockDim.x) {
+        outDistances[row][i] = smemK[i];
+        outIndices[row][i] = smemV[i];
+    }
+}
+
+// FIXME: no TVec specialization
+template<typename T>
+void runTopKSelectMax(Tensor<T> &productDistances,
+                      Tensor<T> &outDistances,
+                      Tensor<int64_t> &outIndices,
+                      int k,
+                      cudaStream_t stream) {
+    FAISS_ASSERT(productDistances.getSize(0) == outDistances.getSize(0));
+    FAISS_ASSERT(productDistances.getSize(0) == outIndices.getSize(0));
+    FAISS_ASSERT(outDistances.getSize(1) == k);
+    FAISS_ASSERT(outIndices.getSize(1) == k);
+    FAISS_ASSERT(k <= 1024);
+
+    if (k == 1) {
+        constexpr int kThreadsPerBlock = 256;
+        constexpr int kRowsPerBlock = 8;
+
+        auto block = dim3(kThreadsPerBlock);
+        auto grid = dim3(faiss::gpu::utils::divUp(outDistances.getSize(0), kRowsPerBlock));
+
+        topkSelectMax1<T, kRowsPerBlock, kThreadsPerBlock>
+            <<<grid, block, 0, stream>>>(productDistances, outDistances, outIndices);
+    } else {
+        constexpr int kThreadsPerBlock = 128;
+
+        auto block = dim3(kThreadsPerBlock);
+        auto grid = dim3(outDistances.getSize(0));
+
+#define RUN_TOPK_SELECT_MAX(NUM_WARP_Q, NUM_THREAD_Q)                           \
+    do {                                                                        \
+        topkSelectMaxK<T, NUM_WARP_Q, NUM_THREAD_Q, kThreadsPerBlock>           \
+            <<<grid, block, 0, stream>>>(productDistances,                      \
+                                         outDistances, outIndices,              \
+                                         k, faiss::gpu::Limits<T>::getMin());   \
+    } while (0)
+
+        if (k <= 32) {
+            RUN_TOPK_SELECT_MAX(32, 2);
+        } else if (k <= 64) {
+            RUN_TOPK_SELECT_MAX(64, 3);
+        } else if (k <= 128) {
+            RUN_TOPK_SELECT_MAX(128, 3);
+        } else if (k <= 256) {
+            RUN_TOPK_SELECT_MAX(256, 4);
+        } else if (k <= 512) {
+            RUN_TOPK_SELECT_MAX(512, 8);
+        } else if (k <= 1024) {
+            RUN_TOPK_SELECT_MAX(1024, 8);
+        } else {
+            FAISS_ASSERT(false);
+        }
+    }
+
+    CUDA_TEST_ERROR();
+}
+//////////////////////////////////////////////////////////////
+
+template<typename T>
+void runTopKSelect(Tensor<T> &productDistances,
+                   Tensor<T> &outDistances,
+                   Tensor<int64_t> &outIndices,
+                   bool dir,
+                   int k,
+                   cudaStream_t stream) {
+    if (dir) {
+        runTopKSelectMax(productDistances,
+                         outDistances,
+                         outIndices,
+                         k,
+                         stream);
+    } else {
+        runTopKSelectMin(productDistances,
+                         outDistances,
+                         outIndices,
+                         k,
+                         stream);
+    }
+}
+
+template<typename T>
+void TopK(T *input,
+          int length,
+          int k,
+          T *output,
+          int64_t *idx,
+//          Ordering order_flag,
+          cudaStream_t stream) {
+
+//    bool dir = (order_flag == Ordering::kAscending ? false : true);
+    bool dir = 0;
+
+    Tensor<T> t_input(input, {1, length});
+    Tensor<T> t_output(output, {1, k});
+    Tensor<int64_t> t_idx(idx, {1, k});
+
+    runTopKSelect(t_input, t_output, t_idx, dir, k, stream);
+}
+
+//INSTANTIATION_TOPK_2(bool);
+//INSTANTIATION_TOPK_2(int8_t);
+//INSTANTIATION_TOPK_2(int16_t);
+INSTANTIATION_TOPK_2(int32_t);
+//INSTANTIATION_TOPK_2(int64_t);
+INSTANTIATION_TOPK_2(float);
+//INSTANTIATION_TOPK_2(double);
+//INSTANTIATION_TOPK(TimeInterval);
+//INSTANTIATION_TOPK(Float128);
+//INSTANTIATION_TOPK(char);
+
+}
+
+void TopK(float *host_input,
+          int length,
+          int k,
+          float *output,
+          int64_t *indices) {
+    float *device_input, *device_output;
+    int64_t *ids;
+
+    cudaMalloc((void **) &device_input, sizeof(float) * length);
+    cudaMalloc((void **) &device_output, sizeof(float) * k);
+    cudaMalloc((void **) &ids, sizeof(int64_t) * k);
+
+    cudaMemcpy(device_input, host_input, sizeof(float) * length, cudaMemcpyHostToDevice);
+
+    gpu::TopK(device_input, length, k, device_output, ids, nullptr);
+
+    cudaMemcpy(output, device_output, sizeof(float) * k, cudaMemcpyDeviceToHost);
+    cudaMemcpy(indices, ids, sizeof(int64_t) * k, cudaMemcpyDeviceToHost);
+
+    cudaFree(device_input);
+    cudaFree(device_output);
+    cudaFree(ids);
+}
+
+}
+}
+}
diff --git a/cpp/src/wrapper/Topk.h b/cpp/src/wrapper/Topk.h
new file mode 100644
index 0000000000000000000000000000000000000000..ed6c233ff667e1bd9d07a8e60c18b945858a2eba
--- /dev/null
+++ b/cpp/src/wrapper/Topk.h
@@ -0,0 +1,61 @@
+////////////////////////////////////////////////////////////////////////////////
+// Copyright 上海赜睿信息科技有限公司(Zilliz) - All Rights Reserved
+// Unauthorized copying of this file, via any medium is strictly prohibited.
+// Proprietary and confidential.
+////////////////////////////////////////////////////////////////////////////////
+
+#pragma once
+
+#include <cstdint>
+#include <cuda_runtime.h>
+
+
+namespace zilliz {
+namespace vecwise {
+namespace engine {
+namespace gpu {
+
+template<typename T>
+void
+TopK(T *input,
+     int length,
+     int k,
+     T *output,
+     int64_t *indices,
+//     Ordering order_flag,
+     cudaStream_t stream = nullptr);
+
+
+#define INSTANTIATION_TOPK_2(T)      \
+    template void                    \
+    TopK(T *input,                   \
+         int length,                 \
+         int k,                      \
+         T *output,                  \
+         int64_t *indices,           \
+         cudaStream_t stream)
+//         Ordering order_flag,      \
+//         cudaStream_t stream)
+
+//extern INSTANTIATION_TOPK_2(int8_t);
+//extern INSTANTIATION_TOPK_2(int16_t);
+extern INSTANTIATION_TOPK_2(int32_t);
+//extern INSTANTIATION_TOPK_2(int64_t);
+extern INSTANTIATION_TOPK_2(float);
+//extern INSTANTIATION_TOPK_2(double);
+//extern INSTANTIATION_TOPK(TimeInterval);
+//extern INSTANTIATION_TOPK(Float128);
+
+}
+
+// User Interface.
+void TopK(float *input,
+          int length,
+          int k,
+          float *output,
+          int64_t *indices);
+
+
+}
+}
+}
diff --git a/cpp/unittest/faiss_wrapper/CMakeLists.txt b/cpp/unittest/faiss_wrapper/CMakeLists.txt
index ed113e5070309cd39bdf00226d84f3c2e7fae695..5be724d3cee26c4af6a87a9bd1a74ebe65031d33 100644
--- a/cpp/unittest/faiss_wrapper/CMakeLists.txt
+++ b/cpp/unittest/faiss_wrapper/CMakeLists.txt
@@ -24,3 +24,10 @@ set(faiss_libs
         cublas
         )
 target_link_libraries(wrapper_test ${unittest_libs} ${faiss_libs})
+
+set(topk_test_src
+        topk_test.cpp
+        ${CMAKE_SOURCE_DIR}/src/wrapper/Topk.cu)
+
+cuda_add_executable(topk_test ${topk_test_src})
+target_link_libraries(topk_test ${unittest_libs} ${faiss_libs})
diff --git a/cpp/unittest/faiss_wrapper/topk_test.cpp b/cpp/unittest/faiss_wrapper/topk_test.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..6943829130267e5280df2c8096b6ae478fc63bd5
--- /dev/null
+++ b/cpp/unittest/faiss_wrapper/topk_test.cpp
@@ -0,0 +1,89 @@
+////////////////////////////////////////////////////////////////////////////////
+// Copyright 上海赜睿信息科技有限公司(Zilliz) - All Rights Reserved
+// Unauthorized copying of this file, via any medium is strictly prohibited.
+// Proprietary and confidential.
+////////////////////////////////////////////////////////////////////////////////
+
+#include <gtest/gtest.h>
+
+#include <cuda_runtime.h>
+
+#include <algorithm>
+#include <cmath>
+#include <cstdlib>
+#include <random>
+#include <vector>
+
+#include "wrapper/Topk.h"
+
+
+using namespace zilliz::vecwise::engine;
+
+constexpr float threshold = 0.00001;
+
+template<typename T>
+void TopK_check(T *data,
+                int length,
+                int k,
+                T *result) {
+
+    std::vector<T> arr(data, data + length);
+    std::sort(arr.begin(), arr.end(), std::less<T>());
+
+    for (int i = 0; i < k; ++i) {
+        ASSERT_TRUE(fabs(arr[i] - result[i]) < threshold);
+    }
+}
+
+TEST(wrapper_topk, Wrapper_Test) {
+    int length = 100000;
+    int k = 1000;
+
+    float *host_input, *host_output;
+    int64_t *ids;
+
+    host_input = (float *) malloc(length * sizeof(float));
+    host_output = (float *) malloc(k * sizeof(float));
+    ids = (int64_t *) malloc(k * sizeof(int64_t));
+
+    std::random_device rd;
+    std::mt19937 gen(rd());
+    std::uniform_real_distribution<> dis(-1.0, 1.0);
+    for (int i = 0; i < length; ++i) {
+        host_input[i] = 1.0 * dis(gen);
+    }
+
+    TopK(host_input, length, k, host_output, ids);
+    TopK_check(host_input, length, k, host_output);
+}
+
+template<typename T>
+void TopK_Test(T factor) {
+    int length = 1000000; // data length
+    int k = 100;
+
+    T *data, *out;
+    int64_t *idx;
+    cudaMallocManaged((void **) &data, sizeof(T) * length);
+    cudaMallocManaged((void **) &out, sizeof(T) * k);
+    cudaMallocManaged((void **) &idx, sizeof(int64_t) * k);
+
+    std::random_device rd;
+    std::mt19937 gen(rd());
+    std::uniform_real_distribution<> dis(-1.0, 1.0);
+
+    for (int i = 0; i < length; i++) {
+        data[i] = factor * dis(gen);
+    }
+
+    cudaMemAdvise(data, sizeof(T) * length, cudaMemAdviseSetReadMostly, 0);
+
+    cudaMemPrefetchAsync(data, sizeof(T) * length, 0);
+
+    gpu::TopK(data, length, k, out, idx, nullptr);
+    cudaDeviceSynchronize();  // the selection kernel is asynchronous; wait before reading managed memory on the host
+    TopK_check(data, length, k, out);
+
+//    order_flag = Ordering::kDescending;
+//    TopK(data, length, k, out, idx, nullptr);
+//    TopK_check(data, length, k, out);
+
+    cudaFree(data);
+    cudaFree(out);
+    cudaFree(idx);
+}
+
+TEST(topk_test, Wrapper_Test) {
+    TopK_Test<float>(1.0);
+}
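
For reference, a minimal host-side sketch of how the public interface declared in Topk.h might be called from application code. Only the zilliz::vecwise::engine::TopK(float*, int, int, float*, int64_t*) overload and its ascending (smallest-first) behaviour come from this change; the surrounding main(), the sample values, and the variable names are illustrative assumptions, not part of the patch.

#include <cstdint>
#include <cstdio>
#include <vector>

#include "wrapper/Topk.h"

int main() {
    // Hypothetical input: 8 unsorted scores; ask for the 3 smallest
    // (the overload above currently selects in ascending order, since dir is fixed to 0).
    std::vector<float> scores = {0.9f, 0.1f, 0.5f, 0.3f, 0.7f, 0.2f, 0.8f, 0.4f};
    int k = 3;

    std::vector<float> top(k);
    std::vector<int64_t> indices(k);

    // Copies the data to the GPU, runs the selection kernel, and copies results back.
    zilliz::vecwise::engine::TopK(scores.data(), (int) scores.size(), k, top.data(), indices.data());

    for (int i = 0; i < k; ++i) {
        std::printf("rank %d: value %f at index %lld\n", i, top[i], (long long) indices[i]);
    }
    return 0;
}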