opr_impl.cpp 3.0 KB
Newer Older
1 2 3 4
/**
 * \file dnn/src/cambricon/checksum/opr_impl.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */

#include "src/cambricon/checksum/opr_impl.h"

#include "src/cambricon/checksum/checksum.mlu.h"

#include "src/cambricon/utils.h"

#include <algorithm>
#include <cstring>
#include <limits>

using namespace megdnn;
using namespace cambricon;

namespace {
M
Megvii Engine Team 已提交
23 24 25
void bang_c_wrapper(
        uint32_t* dst, const uint32_t* src, int nr_elems, cnrtQueue_t queue,
        cnrtCoreVersion_t core_version) {
26 27 28 29 30 31 32 33 34 35 36
    cnrtKernelParamsBuffer_t params;
    cnrt_check(cnrtGetKernelParamsBuffer(&params));
    cnrt_check(cnrtKernelParamsBufferAddParam(params, &dst, sizeof(uint32_t*)));
    cnrt_check(cnrtKernelParamsBufferAddParam(params, &src, sizeof(uint32_t*)));
    cnrt_check(cnrtKernelParamsBufferAddParam(params, &nr_elems, sizeof(int)));
    if (core_version == CNRT_MLU270) {
        cnrtDim3_t dim;
        dim.x = 16;
        dim.y = 1;
        dim.z = 1;
        cnrtFunctionType_t c = CNRT_FUNC_TYPE_UNION4;
M
Megvii Engine Team 已提交
37 38
        cnrt_check(cnrtInvokeKernel_V2(
                (void*)&checksum_kernel_union4, dim, params, c, queue));
39 40 41 42 43 44
    } else if (core_version == CNRT_MLU220) {
        cnrtDim3_t dim;
        dim.x = 4;
        dim.y = 1;
        dim.z = 1;
        cnrtFunctionType_t c = CNRT_FUNC_TYPE_UNION1;
M
Megvii Engine Team 已提交
45 46
        cnrt_check(cnrtInvokeKernel_V2(
                (void*)&checksum_kernel_union1, dim, params, c, queue));
47 48 49 50 51 52 53 54 55 56 57
    }
    after_kernel_launch();
    cnrt_check(cnrtDestroyKernelParamsBuffer(params));
}
}  // namespace

size_t ChecksumForwardImpl::get_workspace_in_bytes(const TensorLayout& /* data */) {
    // exec() lets the kernel write one checksum value into the workspace and
    // copies exactly that many bytes back, so the size is independent of the
    // input layout.
    return sizeof(ChecksumForward::Result::checksum);
}

M
Megvii Engine Team 已提交
58 59
ChecksumForward::Result ChecksumForwardImpl::exec(
        _megdnn_tensor_in data, _megdnn_workspace workspace) {
60 61 62 63 64
    Result result;
    memset(&result, 0, sizeof(result));
    check_exec(data.layout, workspace.size);
    auto queue = cnrt_queue(handle());

65
    auto ptr = static_cast<uint8_t*>(data.raw_ptr());
M
Megvii Engine Team 已提交
66
    size_t size_all = data.layout.shape[0], size_ints = size_all / sizeof(uint32_t);
67
    auto last_val_size = std::min<size_t>(size_all, 4);
M
Megvii Engine Team 已提交
68 69 70
    cnrt_check(cnrtMemcpyAsync(
            &result.last_val, ptr + size_all - last_val_size, last_val_size, queue,
            CNRT_MEM_TRANS_DIR_DEV2HOST));
71 72
    if (size_ints) {
        auto&& device_info = current_device_info();
M
Megvii Engine Team 已提交
73 74
        bang_c_wrapper(
                reinterpret_cast<uint32_t*>(workspace.raw_ptr),
75
                static_cast<uint32_t*>(data.raw_ptr()), size_ints, queue,
M
Megvii Engine Team 已提交
76 77 78 79
                device_info.core_version);
        cnrt_check(cnrtMemcpyAsync(
                &result.checksum, workspace.raw_ptr, sizeof(result.checksum), queue,
                CNRT_MEM_TRANS_DIR_DEV2HOST));
80 81 82 83 84 85
    }
    cnrt_check(cnrtSyncQueue(queue));
    return result;
}

// vim: syntax=cpp.doxygen