/**
 * \file dnn/src/cuda/local/local.cuh
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */
#pragma once

#include <cuda_runtime_api.h>
#include <cublas_v2.h>

namespace megdnn {
namespace cuda {
namespace local {

/**
 * \brief Size query that, going by its name, reports the shared memory
 *        (in bytes) associated with forward_proxy_default.
 *
 * NOTE(review): only the prototype is visible in this header; how IH and IW
 * enter the computation must be confirmed against the definition.
 *
 * \param IH presumably the input height -- TODO confirm
 * \param IW presumably the input width -- TODO confirm
 * \return a size_t value; by the function name, a byte count
 */
size_t forward_proxy_default_share_mem_in_bytes(size_t IH, size_t IW);

22 23
void forward_proxy_default(const float *src, const float *filter, float *dst,
        size_t N,
24 25 26 27 28 29 30 31 32 33 34
        size_t IC, size_t IH, size_t IW,
        size_t OC, size_t OH, size_t OW,
        size_t FH, size_t FW,
        size_t INs, size_t ONs,
        size_t PH, size_t PW,
        size_t SH, size_t SW,
        bool is_xcorr,
        cudaStream_t stream);

/// forward

/**
 * \brief Predicate over a problem description; by its name it reports
 *        whether forward_proxy_convnet is applicable.
 *
 * NOTE(review): the acceptance criteria are in the definition, which is not
 * part of this header.
 *
 * \param N        presumably the batch size -- TODO confirm
 * \param IC,IH,IW presumably input channels / height / width -- TODO confirm
 * \param OC,OH,OW presumably output channels / height / width -- TODO confirm
 * \param FH,FW    presumably filter height / width -- TODO confirm
 * \param INs,ONs  presumably the IN stride and ON stride -- TODO confirm
 * \param PH,PW    presumably padding along height / width -- TODO confirm
 * \param SH,SW    presumably stride along height / width -- TODO confirm
 * \return a bool; presumably true when the convnet forward path may be used
 */
bool can_forward_proxy_convnet(size_t N,
        size_t IC, size_t IH, size_t IW,
        size_t OC, size_t OH, size_t OW,
        size_t FH, size_t FW,
        size_t INs, size_t ONs,
        size_t PH, size_t PW,
        size_t SH, size_t SW);

/**
 * \brief Prototype of the "convnet" forward entry point.
 *
 * \param src           read-only float buffer (pointer to const)
 * \param filter        read-only float buffer (pointer to const)
 * \param dst           writable float buffer
 * \param workspace     writable float scratch buffer; presumably sized via
 *                      get_workspace_in_floats_forward_proxy_convnet --
 *                      TODO confirm
 * \param N             presumably the batch size -- TODO confirm
 * \param IC,IH,IW      presumably input channels / height / width
 * \param OC,OH,OW      presumably output channels / height / width
 * \param FH,FW         presumably filter height / width
 * \param INs,ONs       IN stride and ON stride
 * \param PH,PW         presumably padding along height / width
 * \param SH,SW         presumably stride along height / width
 * \param cublas_handle cuBLAS handle handed to the implementation
 * \param stream        CUDA stream handle handed to the implementation
 * \param one,zero      float pointers; presumably the constants 1 and 0 used
 *                      as scaling factors -- TODO confirm value and memory
 *                      space against the caller
 */
void forward_proxy_convnet(
        const float *src,
        const float *filter,
        float *dst,
        float *workspace,
        size_t N,
        size_t IC,
        size_t IH,
        size_t IW,
        size_t OC,
        size_t OH,
        size_t OW,
        size_t FH,
        size_t FW,
        size_t INs,
        size_t ONs,
        size_t PH,
        size_t PW,
        size_t SH,
        size_t SW,
        cublasHandle_t cublas_handle,
        cudaStream_t stream,
        float *one,
        float *zero);

/**
 * \brief Workspace-size query paired, by name, with forward_proxy_convnet.
 *
 * \param N        presumably the batch size -- TODO confirm
 * \param IC,IH,IW presumably input channels / height / width
 * \param OC,OH,OW presumably output channels / height / width
 * \param FH,FW    presumably filter height / width
 * \param INs,ONs  presumably the IN stride and ON stride
 * \param PH,PW    presumably padding along height / width
 * \param SH,SW    presumably stride along height / width
 * \return a size_t; per the function name, a count of floats (not bytes)
 */
size_t get_workspace_in_floats_forward_proxy_convnet(
        size_t N,
        size_t IC,
        size_t IH,
        size_t IW,
        size_t OC,
        size_t OH,
        size_t OW,
        size_t FH,
        size_t FW,
        size_t INs,
        size_t ONs,
        size_t PH,
        size_t PW,
        size_t SH,
        size_t SW);

/// bwd data

/**
 * \brief Predicate over a problem description; by its name it reports
 *        whether backward_data_proxy_convnet is applicable.
 *
 * NOTE(review): the acceptance criteria are in the definition, which is not
 * part of this header.
 *
 * \param N        presumably the batch size -- TODO confirm
 * \param IC,IH,IW presumably input channels / height / width -- TODO confirm
 * \param OC,OH,OW presumably output channels / height / width -- TODO confirm
 * \param FH,FW    presumably filter height / width -- TODO confirm
 * \param INs,ONs  presumably the IN stride and ON stride -- TODO confirm
 * \param PH,PW    presumably padding along height / width -- TODO confirm
 * \param SH,SW    presumably stride along height / width -- TODO confirm
 * \return a bool; presumably true when the convnet backward-data path may
 *         be used
 */
bool can_backward_data_proxy_convnet(size_t N,
        size_t IC, size_t IH, size_t IW,
        size_t OC, size_t OH, size_t OW,
        size_t FH, size_t FW,
        size_t INs, size_t ONs,
        size_t PH, size_t PW,
        size_t SH, size_t SW);

74
void backward_data_proxy_convnet(const float *filter,
75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98
        const float *diff,
        float *grad,
        float *workspace,
        size_t N,
        size_t IC, size_t IH, size_t IW,
        size_t OC, size_t OH, size_t OW,
        size_t FH, size_t FW,
        size_t INs, size_t ONs, // IN stride and ON stride
        size_t PH, size_t PW,
        size_t SH, size_t SW,
        cublasHandle_t cublas_handle,
        cudaStream_t stream,
        float *one, float *zero);

/**
 * \brief Workspace-size query paired, by name, with
 *        backward_data_proxy_convnet.
 *
 * \param N        presumably the batch size -- TODO confirm
 * \param IC,IH,IW presumably input channels / height / width
 * \param OC,OH,OW presumably output channels / height / width
 * \param FH,FW    presumably filter height / width
 * \param INs,ONs  presumably the IN stride and ON stride
 * \param PH,PW    presumably padding along height / width
 * \param SH,SW    presumably stride along height / width
 * \return a size_t; per the function name, a count of floats (not bytes)
 */
size_t get_workspace_in_floats_backward_data_proxy_convnet(
        size_t N,
        size_t IC,
        size_t IH,
        size_t IW,
        size_t OC,
        size_t OH,
        size_t OW,
        size_t FH,
        size_t FW,
        size_t INs,
        size_t ONs,
        size_t PH,
        size_t PW,
        size_t SH,
        size_t SW);

/// bwd filter

/**
 * \brief Predicate over a problem description; by its name it reports
 *        whether backward_filter_proxy_convnet is applicable.
 *
 * NOTE(review): the acceptance criteria are in the definition, which is not
 * part of this header.
 *
 * \param N        presumably the batch size -- TODO confirm
 * \param IC,IH,IW presumably input channels / height / width -- TODO confirm
 * \param OC,OH,OW presumably output channels / height / width -- TODO confirm
 * \param FH,FW    presumably filter height / width -- TODO confirm
 * \param INs,ONs  presumably the IN stride and ON stride -- TODO confirm
 * \param PH,PW    presumably padding along height / width -- TODO confirm
 * \param SH,SW    presumably stride along height / width -- TODO confirm
 * \return a bool; presumably true when the convnet backward-filter path may
 *         be used
 */
bool can_backward_filter_proxy_convnet(size_t N,
        size_t IC, size_t IH, size_t IW,
        size_t OC, size_t OH, size_t OW,
        size_t FH, size_t FW,
        size_t INs, size_t ONs,
        size_t PH, size_t PW,
        size_t SH, size_t SW);

/**
 * \brief Prototype of the "convnet" backward-filter entry point.
 *
 * \param src           read-only float buffer (pointer to const)
 * \param diff          read-only float buffer (pointer to const); presumably
 *                      the gradient w.r.t. the output -- TODO confirm
 * \param grad          writable float buffer; presumably receives the
 *                      gradient w.r.t. the filter -- TODO confirm
 * \param workspace     writable float scratch buffer; presumably sized via
 *                      get_workspace_in_floats_backward_filter_proxy_convnet
 *                      -- TODO confirm
 * \param N             presumably the batch size -- TODO confirm
 * \param IC,IH,IW      presumably input channels / height / width
 * \param OC,OH,OW      presumably output channels / height / width
 * \param FH,FW         presumably filter height / width
 * \param INs,ONs       IN stride and ON stride
 * \param PH,PW         presumably padding along height / width
 * \param SH,SW         presumably stride along height / width
 * \param cublas_handle cuBLAS handle handed to the implementation
 * \param stream        CUDA stream handle handed to the implementation
 * \param one,zero      float pointers; presumably the constants 1 and 0 used
 *                      as scaling factors -- TODO confirm value and memory
 *                      space against the caller
 */
void backward_filter_proxy_convnet(
        const float *src,
        const float *diff,
        float *grad,
        float *workspace,
        size_t N,
        size_t IC,
        size_t IH,
        size_t IW,
        size_t OC,
        size_t OH,
        size_t OW,
        size_t FH,
        size_t FW,
        size_t INs,
        size_t ONs,
        size_t PH,
        size_t PW,
        size_t SH,
        size_t SW,
        cublasHandle_t cublas_handle,
        cudaStream_t stream,
        float *one,
        float *zero);

/**
 * \brief Workspace-size query paired, by name, with
 *        backward_filter_proxy_convnet.
 *
 * \param N        presumably the batch size -- TODO confirm
 * \param IC,IH,IW presumably input channels / height / width
 * \param OC,OH,OW presumably output channels / height / width
 * \param FH,FW    presumably filter height / width
 * \param INs,ONs  presumably the IN stride and ON stride
 * \param PH,PW    presumably padding along height / width
 * \param SH,SW    presumably stride along height / width
 * \return a size_t; per the function name, a count of floats (not bytes)
 */
size_t get_workspace_in_floats_backward_filter_proxy_convnet(
        size_t N,
        size_t IC,
        size_t IH,
        size_t IW,
        size_t OC,
        size_t OH,
        size_t OW,
        size_t FH,
        size_t FW,
        size_t INs,
        size_t ONs,
        size_t PH,
        size_t PW,
        size_t SH,
        size_t SW);

} // namespace local
} // namespace cuda
} // namespace megdnn

// vim: syntax=cpp.doxygen