/**
 * \file dnn/src/cuda/warp_perspective/backward_data.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
 * implied.
 */
#include "src/cuda/warp_perspective/opr_impl.h"

#include "src/cuda/utils.h"
#include "src/cuda/warp_perspective/common.h"
#include "src/cuda/warp_perspective/helper.h"

namespace megdnn {
namespace cuda {

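// Build the workspace bundle for exec(): one float32 buffer for every
// BFloat16 tensor among {mat, diff, grad} (reserved by the lambda below),
// followed by the workspace of the float32 backward-data kernel itself.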
WorkspaceBundle WarpPerspectiveBackwardDataImpl::get_workspace_bundle(
        void* ptr, const TensorLayout& mat, const TensorLayout& mat_idx,
        const TensorLayout& diff, const TensorLayout& grad) const {
    SmallVector<size_t> sizes;
    TensorLayout fmat = mat;
    TensorLayout fdiff = diff;
    TensorLayout fgrad = grad;
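    // If a tensor arrives as BFloat16, reserve a same-shape float32 buffer
    // for the computation-type conversion.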
    auto get_workspace = [&sizes](TensorLayout& layout) {
        if (layout.dtype == dtype::BFloat16()) {
            layout.dtype = dtype::Float32();
            sizes.push_back(layout.span().dist_byte());
        }
    };
    get_workspace(fmat);
    get_workspace(fdiff);
    get_workspace(fgrad);
    sizes.push_back(
            get_float32_workspace_in_bytes(fmat, mat_idx, fdiff, fgrad));
    return {ptr, std::move(sizes)};
}

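// Backward-data entry point. BFloat16 inputs are converted into float32
// buffers from the workspace bundle, the float32 kernel is run (split into
// chunks when the batch is too large), and grad is converted back at the end.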
void WarpPerspectiveBackwardDataImpl::exec(_megdnn_tensor_in smat,
                                           _megdnn_tensor_in mat_idx,
                                           _megdnn_tensor_in sdiff,
                                           _megdnn_tensor_out sgrad,
                                           _megdnn_workspace sworkspace) {
    check_exec(smat.layout, mat_idx.layout, sdiff.layout, sgrad.layout,
               sworkspace.size);
    TensorND mat = smat;
    TensorND diff = sdiff;
    TensorND grad = sgrad;
    auto bundle =
            get_workspace_bundle(sworkspace.raw_ptr, smat.layout,
                                 mat_idx.layout, sdiff.layout, sgrad.layout);
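    // CompTypeCvter copies BFloat16 tensors into the bundled float32
    // buffers; comp_to_dst_type() converts the result back afterwards.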
    auto ctypecvt = CompTypeCvter<dtype::BFloat16, dtype::Float32>(
            concrete_handle(this->handle()), &bundle);
    if (sgrad.layout.dtype.enumv() == DTypeTrait<dtype::BFloat16>::enumv) {
        ctypecvt.src_to_comp_type(smat, mat)
                .src_to_comp_type(sdiff, diff)
                .src_to_comp_type(sgrad, grad);
    }
    {
        auto workspace = ctypecvt.workspace();
        auto stream = cuda_stream(this->handle());
        auto N = grad.layout.shape[0], C = grad.layout.shape[1],
             IH = grad.layout.shape[2], IW = grad.layout.shape[3],
             OH = diff.layout.shape[2], OW = diff.layout.shape[3];
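        // mat_idx, when non-empty, maps each matrix (and diff entry) to a
        // batch entry of grad; N then becomes the number of matrices rather
        // than the grad batch size.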
        int* midx_ptr = nullptr;
        if (mat_idx.raw_ptr) {
            megdnn_assert(mat_idx.layout.ndim == 1);
            N = mat_idx.layout.shape[0];
            midx_ptr = mat_idx.ptr<int>();
        } else {
            megdnn_assert(mat_idx.layout.ndim == 0);
        }

        auto bval = param().border_val;
        auto bmode = warp_perspective::get_bmode(param().bmode);

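        // A single kernel launch handles at most max_batch_x_channel_size()
        // batch*channel slices; larger workloads fall through to the chunked
        // path below.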
        size_t batch_x_channel_size = N * C;
        size_t max_batch_x_channel = max_batch_x_channel_size();
        if (batch_x_channel_size <= max_batch_x_channel) {
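            // The whole batch fits into a single kernel launch.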
            warp_perspective::backward_data_proxy(
                    mat.ptr<dt_float32>(), midx_ptr, diff.ptr<dt_float32>(),
                    grad.ptr<dt_float32>(),
                    reinterpret_cast<float*>(workspace.raw_ptr), N,
                    grad.layout.shape[0], C, IH, IW, OH, OW, bval, bmode,
                    stream);
        } else {
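            // Otherwise run the kernel repeatedly over chunks of at most
            // max_batch_size images.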
            dt_float32* mat_ptr = mat.ptr<dt_float32>();
            dt_float32* diff_ptr = diff.ptr<dt_float32>();
            dt_float32* grad_ptr = grad.ptr<dt_float32>();
            size_t max_batch_size = max_batch_x_channel / C;
            while (N > 0) {
                size_t curr_batch_size =
                        N > max_batch_size ? max_batch_size : N;
                warp_perspective::backward_data_proxy(
                        mat_ptr, midx_ptr, diff_ptr, grad_ptr,
                        reinterpret_cast<float*>(workspace.raw_ptr),
                        curr_batch_size, grad.layout.shape[0], C, IH, IW, OH,
                        OW, bval, bmode, stream);

                if (N <= max_batch_size) {
                    break;
                } else {
                    N -= max_batch_size;
                    mat_ptr += curr_batch_size * mat.layout.stride[0];
                    diff_ptr += curr_batch_size * diff.layout.stride[0];
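                    // With mat_idx, a chunk may touch any grad batch entry,
                    // so grad_ptr stays at the base address and only the
                    // index pointer advances.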
                    if (midx_ptr == nullptr) {
                        grad_ptr += curr_batch_size * grad.layout.stride[0];
                    } else {
                        midx_ptr += curr_batch_size;
                    }
                }
            }
        }
    }
    if (sgrad.layout.dtype.enumv() == DTypeTrait<dtype::BFloat16>::enumv) {
        ctypecvt.comp_to_dst_type(grad, sgrad);
    }
}

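// Number of bytes of float32 workspace the backward-data kernel needs,
// sized for the largest chunk a single launch may process.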
size_t WarpPerspectiveBackwardDataImpl::get_float32_workspace_in_bytes(
        const TensorLayout& /* mat */, const TensorLayout& mat_idx,
        const TensorLayout& diff, const TensorLayout& grad) const {
    auto N = grad.shape[0], C = grad.shape[1], IH = grad.shape[2],
         IW = grad.shape[3];
    auto OH = diff.shape[2], OW = diff.shape[3];
    auto bmode = warp_perspective::get_bmode(param().bmode);

    size_t max_batch_size = N;
    size_t max_batch_x_channel = max_batch_x_channel_size();
    if (N * C > max_batch_x_channel) {
        /* When the batch is split into chunks, the workspace covers only
           part of grad, so indexing through mat_idx could go out of range. */
        megdnn_assert(mat_idx.ndim == 0,
                      "batch size is too large; unsupported with mat_idx backward");
        max_batch_size = max_batch_x_channel / C;
    }

    auto res = warp_perspective::get_backward_data_workspace_in_bytes(
            max_batch_size, C, IH, IW, OH, OW, bmode);
    return res;
}

}  // namespace cuda
}  // namespace megdnn

// vim: syntax=cpp.doxygen