/**
 * \file dnn/src/cuda/convolution/backward_data/chanwise.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
 * implied.
 */

#include "./algo.h"
#include "src/cuda/convolution/chanwise/kern.cuh"
#include "src/cuda/utils.h"

using namespace megdnn;
using namespace cuda;
using namespace convolution;

bool ConvolutionBackwardDataImpl::AlgoChanwise::is_available(
        const SizeArgs& args) const {
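    // Reject devices whose per-block shared memory cannot hold one channel
    // group's filter (chl_mul * flt_h * flt_w elements of the diff dtype).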
    auto kparam = chanwise::Param::from_fwd_args(args.as_fwd_args());
    auto&& device_prop = cuda::current_device_prop();
    if (device_prop.sharedMemPerBlock <
        kparam.chl_mul * kparam.flt_h * kparam.flt_w * args.diff_layout->dtype.size()) {
        return false;
    }
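    // Only contiguous grad/diff layouts are supported.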
    if (!args.grad_layout->is_contiguous() || !args.diff_layout->is_contiguous()) {
        return false;
    }
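    // No chanwise kernel is instantiated for BFloat16 or QuantizedS8.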
    if ((args.diff_layout->dtype == args.filter_layout->dtype &&
         args.diff_layout->dtype == dtype::BFloat16()) ||
        (args.diff_layout->dtype == args.filter_layout->dtype &&
         args.diff_layout->dtype == dtype::QuantizedS8())) {
        return false;
    }
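    // Remaining requirements: NCHW format, float dtype, 2D spatial filter,
    // exactly one input channel per group (icpg == 1, i.e. channelwise),
    // no dilation, and no filter flipping (cross-correlation mode).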
    auto&& fm = args.filter_meta;
    return fm.format == Param::Format::NCHW &&
           args.diff_layout->dtype.category() == DTypeCategory::FLOAT &&
           fm.spatial_ndim == 2 && fm.icpg == 1 && fm.dilation[0] == 1 &&
           fm.dilation[1] == 1 && !fm.should_flip;
}

size_t ConvolutionBackwardDataImpl::AlgoChanwise::get_workspace_in_bytes(
        const SizeArgs&) const {
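    // The chanwise backward-data kernel needs no auxiliary workspace.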
    return 0;
}

void ConvolutionBackwardDataImpl::AlgoChanwise::exec(const ExecArgs& args) const {
    auto kparam = chanwise::Param::from_fwd_args(args.as_fwd_args());
    auto stream = cuda_stream(args.handle);
    switch (args.diff_layout->dtype.enumv()) {
        case DTypeEnum::Float32:
            return chanwise::run_bwd_data(
                    args.grad_tensor->ptr<float>(), args.diff_tensor->ptr<float>(),
                    args.filter_tensor->ptr<float>(), kparam, stream);

        case DTypeEnum::Float16:
#if CUDA_VERSION >= 9000
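            // On sm_53+ dispatch the native __half kernel; older devices
            // fall back to the dt_float16 path below.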
            if (is_compute_capability_required(5, 3)) {
                return chanwise::run_bwd_data(
                        static_cast<__half*>(args.grad_tensor->raw_ptr),
                        static_cast<__half*>(args.diff_tensor->raw_ptr),
                        static_cast<__half*>(args.filter_tensor->raw_ptr), kparam,
                        stream);
            } else {
                return chanwise::run_bwd_data(
                        args.grad_tensor->ptr<dt_float16>(),
                        args.diff_tensor->ptr<dt_float16>(),
                        args.filter_tensor->ptr<dt_float16>(), kparam, stream);
            }
#else
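            // Built against CUDA < 9.0: only the dt_float16 kernel is compiled.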
            return chanwise::run_bwd_data(
                    args.grad_tensor->ptr<dt_float16>(),
                    args.diff_tensor->ptr<dt_float16>(),
                    args.filter_tensor->ptr<dt_float16>(), kparam, stream);
#endif

        default:
            break;
    }
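    // Only float dtypes pass is_available(), so control should never reach here.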
    megdnn_assert_internal(0);
}

// vim: syntax=cpp.doxygen