kern.cuh 2.5 KB
Newer Older
1 2 3 4
/**
 * \file dnn/src/cuda/convolution/chanwise/kern.cuh
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */
#pragma once

#include "src/cuda/utils.cuh"

#include <cuda_runtime.h>
M
Megvii Engine Team 已提交
16
#include <stdint.h>
17 18 19 20 21 22 23 24 25 26

#if MEGDNN_CC_HOST
#include "src/cuda/convolution/helper.h"
#endif

namespace megdnn {
namespace cuda {
namespace convolution {
namespace chanwise {

M
Megvii Engine Team 已提交
27 28 29
/*!
 * \brief Plain-data bundle of the sizes describing one channel-wise
 *      convolution problem.
 *
 * The struct declares no constructor, so it can be brace-initialized
 * positionally; the member order below is therefore part of its interface and
 * must not be changed. The per-field notes state what from_fwd_args() stores
 * in each member.
 */
struct Param {
    uint32_t batch;       //!< src shape, axis 0
    uint32_t src_chl;     //!< src shape, channel axis
    uint32_t src_h;       //!< src shape, first spatial axis
    uint32_t src_w;       //!< src shape, second spatial axis
    uint32_t chl_mul;     //!< filter_meta.ocpg
    uint32_t flt_h;       //!< filter_meta.spatial[0]
    uint32_t flt_w;       //!< filter_meta.spatial[1]
    uint32_t out_h;       //!< dst shape, first spatial axis
    uint32_t out_w;       //!< dst shape, second spatial axis
    uint32_t pad_h;       //!< filter_meta.padding[0]
    uint32_t pad_w;       //!< filter_meta.padding[1]
    uint32_t stride_h;    //!< filter_meta.stride[0]
    uint32_t stride_w;    //!< filter_meta.stride[1]
    uint32_t dilation_h;  //!< filter_meta.dilation[0]
    uint32_t dilation_w;  //!< filter_meta.dilation[1]
    //! Flag copied verbatim from the caller of from_fwd_args(). The spelling
    //! is kept because the name is visible to users of this struct.
    //! NOTE(review): its meaning is decided by the kernels that read it, not
    //! by anything in this header.
    bool is_compute_deafult;

#if MEGDNN_CC_HOST
    /*!
     * \brief Build a Param from forward-convolution size arguments
     *      (compiled only when MEGDNN_CC_HOST is set).
     *
     * \param args provides src_layout, dst_layout and filter_meta
     * \param is_compute_deafult_ stored unchanged in is_compute_deafult
     * \return the populated Param
     *
     * Axis selection: when filter_meta.format is Format::NCHW the channel axis
     * is 1 and the spatial axes are 2 and 3. Every other format is read with
     * the channel on axis 3 and the spatial axes on 1 and 2
     * (NOTE(review): presumably NHWC; the format is not validated here).
     *
     * Each value is converted with static_cast<uint32_t>; nothing checks that
     * the source value fits.
     */
    static Param from_fwd_args(
            const ForwardSizeArgs& args, bool is_compute_deafult_ = true) {
        auto&& src_shape = args.src_layout->shape;
        auto&& dst_shape = args.dst_layout->shape;
        auto&& meta = args.filter_meta;

        const bool is_nchw = meta.format == param::Convolution::Format::NCHW;
        const size_t chl_axis = is_nchw ? 1 : 3;
        const size_t row_axis = is_nchw ? 2 : 1;
        const size_t col_axis = row_axis + 1;

        // Assign by name instead of positionally so that each value is
        // visibly tied to the member it fills.
        Param ret;
        ret.batch = static_cast<uint32_t>(src_shape[0]);
        ret.src_chl = static_cast<uint32_t>(src_shape[chl_axis]);
        ret.src_h = static_cast<uint32_t>(src_shape[row_axis]);
        ret.src_w = static_cast<uint32_t>(src_shape[col_axis]);
        ret.chl_mul = static_cast<uint32_t>(meta.ocpg);
        ret.flt_h = static_cast<uint32_t>(meta.spatial[0]);
        ret.flt_w = static_cast<uint32_t>(meta.spatial[1]);
        ret.out_h = static_cast<uint32_t>(dst_shape[row_axis]);
        ret.out_w = static_cast<uint32_t>(dst_shape[col_axis]);
        ret.pad_h = static_cast<uint32_t>(meta.padding[0]);
        ret.pad_w = static_cast<uint32_t>(meta.padding[1]);
        ret.stride_h = static_cast<uint32_t>(meta.stride[0]);
        ret.stride_w = static_cast<uint32_t>(meta.stride[1]);
        ret.dilation_h = static_cast<uint32_t>(meta.dilation[0]);
        ret.dilation_w = static_cast<uint32_t>(meta.dilation[1]);
        ret.is_compute_deafult = is_compute_deafult_;
        return ret;
    }
#endif
};
58

M
Megvii Engine Team 已提交
59 60 61 62
/*!
 * \brief Declaration of the "small" backward-data entry point of the
 *      channel-wise convolution; the definition is not in this header.
 *
 * NOTE(review): what makes a problem "small" is decided by the definition and
 * its callers, not here — check them before choosing between this function
 * and run_bwd_data.
 *
 * \tparam T element type shared by the three buffers
 * \param src_grad the only non-const buffer; presumably receives the gradient
 *      with respect to the convolution input — TODO confirm
 * \param dst_grad read-only; presumably the gradient with respect to the
 *      convolution output
 * \param flt read-only; presumably the filter weights
 * \param param problem sizes
 * \param stream CUDA stream handed to the implementation
 */
template <typename T>
void run_bwd_data_small(
        T* src_grad, const T* dst_grad, const T* flt, const Param& param,
        cudaStream_t stream);
63

M
Megvii Engine Team 已提交
64 65 66 67
/*!
 * \brief Declaration of the backward-data entry point of the channel-wise
 *      convolution; the definition is not in this header.
 *
 * \tparam T element type shared by the three buffers
 * \param src_grad the only non-const buffer; presumably receives the gradient
 *      with respect to the convolution input — TODO confirm
 * \param dst_grad read-only; presumably the gradient with respect to the
 *      convolution output
 * \param flt read-only; presumably the filter weights
 * \param param problem sizes
 * \param stream CUDA stream handed to the implementation
 */
template <typename T>
void run_bwd_data(
        T* src_grad, const T* dst_grad, const T* flt, const Param& param,
        cudaStream_t stream);
68

69 70 71 72
/*!
 * \brief Declaration of the large-filter depthwise backward entry point; the
 *      definition is not in this header.
 *
 * NOTE(review): the buffers are named dst/src here, whereas the sibling
 * declarations use src_grad/dst_grad. Which tensor each pointer refers to in
 * the backward pass cannot be determined from this header — confirm against
 * the definition.
 *
 * \tparam T element type shared by the three buffers
 * \param dst the only non-const buffer, so presumably the one written to
 * \param src read-only buffer
 * \param flt read-only; presumably the filter weights
 * \param param problem sizes
 * \param stream CUDA stream handed to the implementation
 */
template <typename T>
void run_bwd_depthwise_large_filter(
        T* dst, const T* src, const T* flt, const Param& param, cudaStream_t stream);

M
Megvii Engine Team 已提交
73 74 75 76
/*!
 * \brief Declaration of the backward-filter entry point of the channel-wise
 *      convolution; the definition is not in this header.
 *
 * \tparam T element type shared by the three buffers
 * \param filter_grad the only non-const buffer; presumably receives the
 *      gradient with respect to the filter — TODO confirm
 * \param src read-only; presumably the convolution input
 * \param dst_grad read-only; presumably the gradient with respect to the
 *      convolution output
 * \param param problem sizes
 * \param stream CUDA stream handed to the implementation
 */
template <typename T>
void run_bwd_filter(
        T* filter_grad, const T* src, const T* dst_grad, const Param& param,
        cudaStream_t stream);
77

M
Megvii Engine Team 已提交
78 79 80 81
}  // namespace chanwise
}  // namespace convolution
}  // namespace cuda
}  // namespace megdnn
82 83

// vim: ft=cpp syntax=cpp.doxygen