opr_impl.cpp 2.3 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59
/**
 * \file dnn/src/cuda/relayout_format/opr_impl.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */

#include "src/cuda/relayout_format/opr_impl.h"
#include "src/cuda/handle.h"

using namespace megdnn;
using namespace cuda;

/*!
 * \brief Relayout \p src into \p dst on CUDA for the NCHW4 <-> CHWN4 modes.
 *
 * Only Mode::NCHW4_CHWN4 and Mode::CHWN4_NCHW4 are accepted; any other mode
 * fails the assertion below.
 *
 * QuantizedS8 inputs are handled as a plain 2-D transpose. The src buffer is
 * viewed as a contiguous (row, col) matrix of Int32. It is then written to
 * dst through a view with the same shape but transposed strides (1, row).
 * The matrix extents are taken from src.layout[0..3]. NOTE(review): this
 * presumes the innermost layout dimension packs 4 int8 values, so one Int32
 * element moves one packed group; confirm against the layout deduction.
 *
 * Every other dtype goes through deduce_exec_layout() and a generic
 * RelayoutForward.
 *
 * \param src input tensor
 * \param dst output tensor
 * The workspace argument is not used.
 */
void RelayoutFormatImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst,
                              _megdnn_workspace /* workspace */) {
    using Mode = param::RelayoutFormat::Mode;
    auto src_dtype = src.layout.dtype;
    const auto mode = param().mode;
    megdnn_assert(mode == Mode::NCHW4_CHWN4 || mode == Mode::CHWN4_NCHW4,
                  "relayout format of cuda only support NCHW4->CHWN4 or "
                  "CHWN4->NCHW4");
    if (src_dtype.enumv() == DTypeEnum::QuantizedS8) {
        // Collapse the src dims into a (row, col) matrix; the relayout then
        // reduces to transposing that matrix.
        size_t row = 0;
        size_t col = 0;
        if (mode == Mode::NCHW4_CHWN4) {
            row = src.layout[0];
            col = src.layout[1] * src.layout[2] * src.layout[3];
        } else {
            row = src.layout[0] * src.layout[1] * src.layout[2];
            col = src.layout[3];
        }
        // Reinterpret both buffers as Int32 matrices. The output view keeps
        // the (row, col) shape but uses strides (1, row), so element (i, j)
        // lands at offset j * row + i, i.e. a transpose.
        TensorND trans_in, trans_out;
        trans_in.raw_ptr = src.raw_ptr;
        trans_in.layout = {{row, col}, dtype::Int32()};
        trans_in.layout.init_contiguous_stride();
        trans_out.raw_ptr = dst.raw_ptr;
        trans_out.layout = trans_in.layout;
        trans_out.layout.stride[0] = 1;
        trans_out.layout.stride[1] = row;
        handle()->create_operator<RelayoutForward>()->exec(trans_in,
                                                           trans_out);
        return;
    }
    // Generic path: let the base class compute matching src/dst layouts and
    // delegate the copy to RelayoutForward.
    TensorLayout exec_src, exec_dst;
    deduce_exec_layout(src.layout, dst.layout, exec_src, exec_dst);
    TensorND exec_src_nd{src.raw_ptr, exec_src};
    TensorND exec_dst_nd{dst.raw_ptr, exec_dst};
    handle()->create_operator<RelayoutForward>()->exec(exec_src_nd,
                                                       exec_dst_nd);
}

/*!
 * \brief Workspace size required by exec() for the given layouts.
 *
 * exec() ignores its workspace argument, so no scratch memory is requested
 * regardless of the src/dst layouts.
 *
 * \return always 0
 */
size_t RelayoutFormatImpl::get_workspace_in_bytes(const TensorLayout& /*src*/,
                                                  const TensorLayout& /*dst*/) {
    constexpr size_t kNoWorkspaceBytes = 0;
    return kNoWorkspaceBytes;
}

// vim: syntax=cpp.doxygen