/**
 * \file dnn/src/cambricon/checksum/opr_impl.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
 * implied.
 */

#include "src/cambricon/checksum/opr_impl.h"
#include "src/cambricon/checksum/checksum.mlu.h"
#include "src/cambricon/utils.h"

#include <algorithm>
#include <cstring>
#include <limits>

using namespace megdnn;
using namespace cambricon;

namespace {

/*!
 * \brief Launch the BANG-C checksum kernel that matches \p core_version.
 *
 * \param dst          device buffer the kernel writes its result into
 *                     (exec() copies sizeof(Result::checksum) bytes back from it)
 * \param src          device pointer to \p nr_elems uint32_t words to checksum
 * \param nr_elems     number of uint32_t words in \p src
 * \param queue        CNRT queue the kernel is enqueued on (launch is async)
 * \param core_version MLU core version; only CNRT_MLU270 and CNRT_MLU220 have a
 *                     kernel configuration here, anything else raises an error
 */
void bang_c_wrapper(
        uint32_t* dst, const uint32_t* src, int nr_elems, cnrtQueue_t queue,
        cnrtCoreVersion_t core_version) {
    // Reject unsupported cores *before* allocating the params buffer: previously
    // the launch was silently skipped, leaving *dst uninitialized, and the caller
    // then reported that garbage as the checksum.
    megdnn_assert(
            core_version == CNRT_MLU270 || core_version == CNRT_MLU220,
            "checksum: unsupported cambricon core version %d",
            static_cast<int>(core_version));

    cnrtKernelParamsBuffer_t params;
    cnrt_check(cnrtGetKernelParamsBuffer(&params));
    // Parameter order must match the kernel signature: (dst, src, nr_elems).
    cnrt_check(cnrtKernelParamsBufferAddParam(params, &dst, sizeof(uint32_t*)));
    cnrt_check(cnrtKernelParamsBufferAddParam(params, &src, sizeof(uint32_t*)));
    cnrt_check(cnrtKernelParamsBufferAddParam(params, &nr_elems, sizeof(int)));

    if (core_version == CNRT_MLU270) {
        cnrtDim3_t dim;
        dim.x = 16;
        dim.y = 1;
        dim.z = 1;
        cnrtFunctionType_t c = CNRT_FUNC_TYPE_UNION4;
        cnrt_check(cnrtInvokeKernel_V2(
                (void*)&checksum_kernel_union4, dim, params, c, queue));
    } else if (core_version == CNRT_MLU220) {
        cnrtDim3_t dim;
        dim.x = 4;
        dim.y = 1;
        dim.z = 1;
        cnrtFunctionType_t c = CNRT_FUNC_TYPE_UNION1;
        cnrt_check(cnrtInvokeKernel_V2(
                (void*)&checksum_kernel_union1, dim, params, c, queue));
    }
    after_kernel_launch();
    cnrt_check(cnrtDestroyKernelParamsBuffer(params));
}

}  // namespace

/*!
 * \brief Workspace needed by exec(): room for the single checksum value the
 *        kernel writes before it is copied back to the host.
 */
size_t ChecksumForwardImpl::get_workspace_in_bytes(const TensorLayout& /* data */) {
    size_t ws_size = sizeof(ChecksumForward::Result::checksum);
    return ws_size;
}

/*!
 * \brief Compute the checksum of \p data on the device and return it together
 *        with the trailing (up to 4) bytes of the input.
 *
 * layout.shape[0] is treated as the total byte count. Only the leading
 * size_all / 4 whole uint32_t words go through the kernel; the last up-to-4
 * bytes are copied separately into Result::last_val. The call blocks on the
 * queue before returning.
 */
ChecksumForward::Result ChecksumForwardImpl::exec(
        _megdnn_tensor_in data, _megdnn_workspace workspace) {
    Result result;
    memset(&result, 0, sizeof(result));
    check_exec(data.layout, workspace.size);
    auto queue = cnrt_queue(handle());

    auto ptr = static_cast<uint8_t*>(data.raw_ptr());
    size_t size_all = data.layout.shape[0],
           size_ints = size_all / sizeof(uint32_t);

    // Copy the trailing min(size_all, 4) bytes; with fewer than 4 bytes they land
    // at the start of the zero-initialized last_val field.
    auto last_val_size = std::min<size_t>(size_all, 4);
    cnrt_check(cnrtMemcpyAsync(
            &result.last_val, ptr + size_all - last_val_size, last_val_size, queue,
            CNRT_MEM_TRANS_DIR_DEV2HOST));

    if (size_ints) {
        // The kernel receives its element count as an int; refuse inputs whose
        // word count would not fit instead of letting it wrap.
        megdnn_assert(
                size_ints <= static_cast<size_t>(std::numeric_limits<int>::max()),
                "checksum: input of %zu words exceeds the kernel's int element count",
                size_ints);
        auto&& device_info = current_device_info();
        bang_c_wrapper(
                reinterpret_cast<uint32_t*>(workspace.raw_ptr),
                static_cast<uint32_t*>(data.raw_ptr()), static_cast<int>(size_ints),
                queue, device_info.core_version);
        cnrt_check(cnrtMemcpyAsync(
                &result.checksum, workspace.raw_ptr, sizeof(result.checksum), queue,
                CNRT_MEM_TRANS_DIR_DEV2HOST));
    }

    // The device-to-host copies above are asynchronous; wait for them so that
    // `result` is fully populated before it is returned.
    cnrt_check(cnrtSyncQueue(queue));
    return result;
}

// vim: syntax=cpp.doxygen