api_cache.h 3.3 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18
/**
 * \file dnn/src/cuda/api_cache.h
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
 * implied.
 */

#pragma once

#include "src/common/api_cache.h"
#include "src/cuda/cudnn_wrapper.h"

namespace megdnn {
19 20 21 22
class CudnnConvDescParam {
public:
    cudnnConvolutionDescriptor_t value;
    Empty serialize(StringSerializer& ser, Empty) {
23 24 25 26 27
        constexpr int maxNbDims = CUDNN_DIM_MAX - 2;
        int nbDims = maxNbDims;
        int padA[maxNbDims];
        int strideA[maxNbDims];
        int dilationA[maxNbDims];
28 29
        cudnnConvolutionMode_t mode;
        cudnnDataType_t computeType;
30 31 32
        cudnnGetConvolutionNdDescriptor(value, maxNbDims, &nbDims, padA,
                                        strideA, dilationA, &mode,
                                        &computeType);
33 34 35 36 37
        ser.write_plain(nbDims);
        for (int i = 0; i < nbDims; ++i) {
            ser.write_plain(padA[i]);
            ser.write_plain(strideA[i]);
            ser.write_plain(dilationA[i]);
38
        }
39 40 41 42 43
        ser.write_plain(mode);
        ser.write_plain(computeType);
        return Empty{};
    }
};
44

45 46 47 48
/*!
 * \brief Cache-key adaptor for a cuDNN tensor descriptor.
 *
 * Wraps a cudnnTensorDescriptor_t so that its shape, strides and data type
 * can be folded into a StringSerializer key.
 */
class CudnnTensorDescParam {
public:
    cudnnTensorDescriptor_t value;

    /*!
     * \brief Append a byte encoding of the tensor descriptor to \p ser.
     *
     * On success the encoding is: nbDims, then a (dim, stride) pair per
     * dimension, then the data type.
     *
     * If the descriptor cannot be queried, the int -1 followed by the cuDNN
     * status is written instead, so uninitialized buffers are never read.
     */
    Empty serialize(StringSerializer& ser, Empty) {
        // Query with cuDNN's own maximum rank rather than MEGDNN_MAX_NDIM.
        // cuDNN reports the descriptor's actual rank, and a smaller buffer
        // would let the loop below read past the arrays.
        constexpr int maxNbDims = CUDNN_DIM_MAX;
        int nbDims = 0;
        cudnnDataType_t dataType{};
        int dimA[maxNbDims];
        int strideA[maxNbDims];

        cudnnStatus_t status = cudnnGetTensorNdDescriptor(
                value, maxNbDims, &dataType, &nbDims, dimA, strideA);
        if (status != CUDNN_STATUS_SUCCESS) {
            // -1 is never a valid rank, so this cannot alias a real key.
            ser.write_plain(static_cast<int>(-1));
            ser.write_plain(status);
            return Empty{};
        }

        // Defensive clamp: never iterate beyond what the buffers hold.
        if (nbDims > maxNbDims) {
            nbDims = maxNbDims;
        } else if (nbDims < 0) {
            nbDims = 0;
        }

        ser.write_plain(nbDims);
        for (int i = 0; i < nbDims; ++i) {
            ser.write_plain(dimA[i]);
            ser.write_plain(strideA[i]);
        }
        ser.write_plain(dataType);
        return Empty{};
    }
};

/*!
 * \brief Cache-key adaptor for a cuDNN filter descriptor.
 *
 * Wraps a cudnnFilterDescriptor_t so that its dimensions, data type and
 * tensor format can be folded into a StringSerializer key.
 */
class CudnnFilterDescParam {
public:
    cudnnFilterDescriptor_t value;

    /*!
     * \brief Append a byte encoding of the filter descriptor to \p ser.
     *
     * On success the encoding is: nbDims, then each filter dimension, then
     * the data type and the tensor format.
     *
     * If the descriptor cannot be queried, the int -1 followed by the cuDNN
     * status is written instead, so uninitialized buffers are never read.
     */
    Empty serialize(StringSerializer& ser, Empty) {
        // Query with cuDNN's own maximum rank rather than MEGDNN_MAX_NDIM.
        // cuDNN reports the descriptor's actual rank, and a smaller buffer
        // would let the loop below read past filterDimA.
        constexpr int maxNbDims = CUDNN_DIM_MAX;
        int nbDims = 0;
        cudnnDataType_t dataType{};
        cudnnTensorFormat_t format{};
        int filterDimA[maxNbDims];

        cudnnStatus_t status = cudnnGetFilterNdDescriptor(
                value, maxNbDims, &dataType, &format, &nbDims, filterDimA);
        if (status != CUDNN_STATUS_SUCCESS) {
            // -1 is never a valid rank, so this cannot alias a real key.
            ser.write_plain(static_cast<int>(-1));
            ser.write_plain(status);
            return Empty{};
        }

        // Defensive clamp: never iterate beyond what the buffer holds.
        if (nbDims > maxNbDims) {
            nbDims = maxNbDims;
        } else if (nbDims < 0) {
            nbDims = 0;
        }

        ser.write_plain(nbDims);
        for (int i = 0; i < nbDims; ++i) {
            ser.write_plain(filterDimA[i]);
        }
        ser.write_plain(dataType);
        ser.write_plain(format);
        return Empty{};
    }
};

/*!
 * \brief Round-trippable wrapper around a cuDNN algorithm perf record.
 *
 * \tparam T perf-record type. It must expose the members algo, status,
 *         time, memory, determinism and mathType; presumably one of the
 *         cudnnConvolution*AlgoPerf_t structs.
 */
template <typename T>
class CudnnConvAlgoPerfParam {
public:
    T value;

    /*!
     * \brief Write the perf record into \p ser.
     *
     * Field order is algo, status, time, memory, determinism, mathType.
     * It must stay identical to the order used by deserialize().
     */
    Empty serialize(StringSerializer& ser, Empty) {
        T& perf = value;
        ser.write_plain(perf.algo);
        ser.write_plain(perf.status);
        ser.write_plain(perf.time);
        ser.write_plain(perf.memory);
        ser.write_plain(perf.determinism);
        ser.write_plain(perf.mathType);
        return Empty{};
    }

    /*!
     * \brief Restore the perf record from \p ser, in the same field order
     *        that serialize() writes.
     */
    Empty deserialize(StringSerializer& ser, Empty) {
        T& perf = value;
        ser.read_plain(&perf.algo);
        ser.read_plain(&perf.status);
        ser.read_plain(&perf.time);
        ser.read_plain(&perf.memory);
        ser.read_plain(&perf.determinism);
        ser.read_plain(&perf.mathType);
        return Empty{};
    }
};
}  // namespace megdnn