// L2Norm.cuh
/**
 * Copyright (c) Facebook, Inc. and its affiliates.
 *
 * This source code is licensed under the MIT license found in the
 * LICENSE file in the root directory of this source tree.
 */


#pragma once

#include <faiss/gpu/utils/Tensor.cuh>

namespace faiss { namespace gpu {

/// Computes L2 norms of the vectors held in `input`, writing one value per
/// vector into `output`.
///
/// NOTE(review): the exact semantics below are inferred from the parameter
/// names; they are not visible in this header. Confirm against the
/// definition in L2Norm.cu.
///
/// @param input         2-D float input tensor.
/// @param inputRowMajor presumably true when each row of `input` is one
///                      vector, and false when vectors are laid out as
///                      columns.
/// @param output        1-D float tensor that receives the norms.
///                      Presumably its size must equal the number of
///                      vectors in `input`.
/// @param normSquared   presumably true to write ||x||^2 instead of ||x||.
/// @param stream        CUDA stream the work is presumably enqueued on.
void runL2Norm(Tensor<float, 2, true>& input,
               bool inputRowMajor,
               Tensor<float, 1, true>& output,
               bool normSquared,
               cudaStream_t stream);

#ifdef FAISS_USE_FLOAT16
/// Half-precision-input overload of runL2Norm. It is only declared when
/// FAISS_USE_FLOAT16 is defined. The parameters have the same meaning as in
/// the float overload. Note that `output` is still a float tensor, even
/// though `input` holds `half` values.
void runL2Norm(Tensor<half, 2, true>& input,
               bool inputRowMajor,
               Tensor<float, 1, true>& output,
               bool normSquared,
               cudaStream_t stream);
#endif

} } // namespace