utils.cpp 1.4 KB
Newer Older
M
Megvii Engine Team 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54
/**
 * \file src/nccl/utils.cpp
 * MegRay is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */

#include "utils.h"

namespace MegRay {

/**
 * Translate a MegRay data type into the matching NCCL data type.
 *
 * \param dtype MegRay element type to convert.
 * \return The NCCL data type corresponding to \p dtype.
 *
 * Any dtype not listed in the table below is rejected via
 * MEGRAY_THROW with the message "unknown dtype".
 */
ncclDataType_t get_nccl_dtype(const DType dtype) {
    // One row per supported MegRay dtype and its NCCL counterpart.
    struct DTypeMapping {
        DType megray_type;
        ncclDataType_t nccl_type;
    };
    static const DTypeMapping kMappings[] = {
        {MEGRAY_INT8, ncclInt8},       {MEGRAY_UINT8, ncclUint8},
        {MEGRAY_INT32, ncclInt32},     {MEGRAY_UINT32, ncclUint32},
        {MEGRAY_INT64, ncclInt64},     {MEGRAY_UINT64, ncclUint64},
        {MEGRAY_FLOAT16, ncclFloat16}, {MEGRAY_FLOAT32, ncclFloat32},
        {MEGRAY_FLOAT64, ncclFloat64},
    };

    for (const auto& mapping : kMappings) {
        if (mapping.megray_type == dtype) {
            return mapping.nccl_type;
        }
    }

    // No table row matched: the dtype has no NCCL equivalent here.
    MEGRAY_THROW("unknown dtype");
}

/**
 * Translate a MegRay reduction operator into the matching NCCL operator.
 *
 * \param red_op MegRay reduction operator to convert.
 * \return ncclSum, ncclMax or ncclMin for MEGRAY_SUM, MEGRAY_MAX or
 *         MEGRAY_MIN respectively.
 *
 * Any other operator is rejected via MEGRAY_THROW with the message
 * "unknown reduce op".
 */
ncclRedOp_t get_nccl_reduce_op(const ReduceOp red_op) {
    if (red_op == MEGRAY_SUM) {
        return ncclSum;
    }
    if (red_op == MEGRAY_MAX) {
        return ncclMax;
    }
    if (red_op == MEGRAY_MIN) {
        return ncclMin;
    }

    // None of the supported operators matched.
    MEGRAY_THROW("unknown reduce op");
}

} // namespace MegRay