提交 3116e128 编写于 作者: M Megvii Engine Team

fix(ci/integration_test): fix benchmark torch version

GitOrigin-RevId: bd964ed505e8ca2a544db66a89dc109ff289aaca
上级 66f13586
此差异已折叠。
# This code is licensed under the MIT License. See the FindBANG.cmake script
# for the text of the license.
# The MIT License
#
# License for the specific language governing rights and limitations under
# Permission is hereby granted, free of charge, to any person obtaining a
# copy of this software and associated documentation files (the "Software"),
# to deal in the Software without restriction, including without limitation
# the rights to use, copy, modify, merge, publish, distribute, sublicense,
# and/or sell copies of the Software, and to permit persons to whom the
# Software is furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included
# in all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
# OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
# DEALINGS IN THE SOFTWARE.
#
#######################################################################
# This converts a file written in makefile syntax into one that can be included
# by CMake.
file(READ ${input_file} depend_text)
if (NOT "${depend_text}" STREQUAL "")
# message("FOUND DEPENDS")
string(REPLACE "\\ " " " depend_text ${depend_text})
# This works for the cncc -M generated dependency files.
string(REGEX REPLACE "^.* : " "" depend_text ${depend_text})
string(REGEX REPLACE "[ \\\\]*\n" ";" depend_text ${depend_text})
set(dependency_list "")
foreach(file ${depend_text})
string(REGEX REPLACE "^ +" "" file ${file})
# OK, now if we had a UNC path, cncc has a tendency to only output the first '/'
# instead of '//'. Here we will test to see if the file exists, if it doesn't then
# try to prepend another '/' to the path and test again. If it still fails remove the
# path.
if(NOT EXISTS "${file}")
if (EXISTS "/${file}")
set(file "/${file}")
else()
message(WARNING " Removing non-existent dependency file: ${file}")
set(file "")
endif()
endif()
if(NOT IS_DIRECTORY "${file}")
# If softlinks start to matter, we should change this to REALPATH. For now we need
# to flatten paths, because cncc can generate stuff like /bin/../include instead of
# just /include.
get_filename_component(file_absolute "${file}" ABSOLUTE)
list(APPEND dependency_list "${file_absolute}")
endif()
endforeach()
else()
# message("FOUND NO DEPENDS")
endif()
# Remove the duplicate entries and sort them.
list(REMOVE_DUPLICATES dependency_list)
list(SORT dependency_list)
foreach(file ${dependency_list})
set(bang_cncc_depend "${bang_cncc_depend} \"${file}\"\n")
endforeach()
file(WRITE ${output_file} "# Generated by: make2cmake.cmake\nSET(BANG_CNCC_DEPEND\n ${bang_cncc_depend})\n\n")
# This code is licensed under the MIT License. See the FindBANG.cmake script
# for the text of the license.
# The MIT License
#
# License for the specific language governing rights and limitations under
# Permission is hereby granted, free of charge, to any person obtaining a
# copy of this software and associated documentation files (the "Software"),
# to deal in the Software without restriction, including without limitation
# the rights to use, copy, modify, merge, publish, distribute, sublicense,
# and/or sell copies of the Software, and to permit persons to whom the
# Software is furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included
# in all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
# OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
# DEALINGS IN THE SOFTWARE.
#
#######################################################################
# Parses a .cnbin file produced by cncc and reports statistics about the file.
file(READ ${input_file} file_text)
if (NOT "${file_text}" STREQUAL "")
string(REPLACE ";" "\\;" file_text ${file_text})
string(REPLACE "\ncode" ";code" file_text ${file_text})
list(LENGTH file_text len)
foreach(line ${file_text})
# Only look at "code { }" blocks.
if(line MATCHES "^code")
# Break into individual lines.
string(REGEX REPLACE "\n" ";" line ${line})
foreach(entry ${line})
# Extract kernel names.
if (${entry} MATCHES "[^g]name = ([^ ]+)")
set(entry "${CMAKE_MATCH_1}")
# Check to see if the kernel name starts with "_"
set(skip FALSE)
# if (${entry} MATCHES "^_")
# Skip the rest of this block.
# message("Skipping ${entry}")
# set(skip TRUE)
# else ()
message("Kernel: ${entry}")
# endif ()
endif()
# Skip the rest of the block if necessary
if(NOT skip)
# Registers
if (${entry} MATCHES "reg([ ]+)=([ ]+)([^ ]+)")
set(entry "${CMAKE_MATCH_3}")
message("Registers: ${entry}")
endif()
# Local memory
if (${entry} MATCHES "lmem([ ]+)=([ ]+)([^ ]+)")
set(entry "${CMAKE_MATCH_3}")
message("Local: ${entry}")
endif()
# Shared memory
if (${entry} MATCHES "smem([ ]+)=([ ]+)([^ ]+)")
set(entry "${CMAKE_MATCH_3}")
message("Shared: ${entry}")
endif()
if (${entry} MATCHES "^}")
message("")
endif()
endif()
endforeach()
endif()
endforeach()
else()
# message("FOUND NO DEPENDS")
endif()
# This code is licensed under the MIT License. See the FindBANG.cmake script
# for the text of the license.
# The MIT License
#
# License for the specific language governing rights and limitations under
# Permission is hereby granted, free of charge, to any person obtaining a
# copy of this software and associated documentation files (the "Software"),
# to deal in the Software without restriction, including without limitation
# the rights to use, copy, modify, merge, publish, distribute, sublicense,
# and/or sell copies of the Software, and to permit persons to whom the
# Software is furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included
# in all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
# OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
# DEALINGS IN THE SOFTWARE.
##########################################################################
# This file runs the cncc commands to produce the desired output file along with
# the dependency file needed by CMake to compute dependencies. In addition the
# file checks the output of each command and if the command fails it deletes the
# output files.
# Input variables
#
# verbose:BOOL=<> OFF: Be as quiet as possible (default)
# ON : Describe each step
#
# build_configuration:STRING=<> Typically one of Debug, MinSizeRel, Release, or
# RelWithDebInfo, but it should match one of the
# entries in BANG_HOST_FLAGS. This is the build
# configuration used when compiling the code. If
# blank or unspecified Debug is assumed as this is
# what CMake does.
#
# generated_file:STRING=<> File to generate. This argument must be passed in.
#
# generated_cnbin_file:STRING=<> File to generate. This argument must be passed
# in if build_cnbin is true.
if(NOT generated_file)
message(FATAL_ERROR "You must specify generated_file on the command line")
endif()
# Set these up as variables to make reading the generated file easier
set(CMAKE_COMMAND "@CMAKE_COMMAND@") # path
set(source_file "@source_file@") # path
set(CNCC_generated_dependency_file "@CNCC_generated_dependency_file@") # path
set(cmake_dependency_file "@cmake_dependency_file@") # path
set(BANG_make2cmake "@BANG_make2cmake@") # path
set(BANG_parse_cnbin "@BANG_parse_cnbin@") # path
set(build_cnbin @build_cnbin@) # bool
set(BANG_HOST_COMPILER "@BANG_HOST_COMPILER@") # path
# We won't actually use these variables for now, but we need to set this, in
# order to force this file to be run again if it changes.
set(generated_file_path "@generated_file_path@") # path
set(generated_file_internal "@generated_file@") # path
set(generated_cnbin_file_internal "@generated_cnbin_file@") # path
set(BANG_CNCC_EXECUTABLE "@BANG_CNCC_EXECUTABLE@") # path
set(BANG_CNCC_FLAGS @BANG_CNCC_FLAGS@ ;; @BANG_WRAP_OPTION_CNCC_FLAGS@) # list
@BANG_CNCC_FLAGS_CONFIG@
set(cncc_flags @cncc_flags@) # list
set(BANG_CNCC_INCLUDE_ARGS "@BANG_CNCC_INCLUDE_ARGS@") # list (needs to be in quotes to handle spaces properly).
set(format_flag "@format_flag@") # string
set(bang_language_flag @bang_language_flag@) # list
if(build_cnbin AND NOT generated_cnbin_file)
message(FATAL_ERROR "You must specify generated_cnbin_file on the command line")
endif()
# This is the list of host compilation flags. It C or CXX should already have
# been chosen by FindBANG.cmake.
@BANG_HOST_FLAGS@
# Take the compiler flags and package them up to be sent to the compiler via -Xcompiler
set(cncc_host_compiler_flags "")
# If we weren't given a build_configuration, use Debug.
if(NOT build_configuration)
set(build_configuration Debug)
endif()
string(TOUPPER "${build_configuration}" build_configuration)
#message("BANG_CNCC_HOST_COMPILER_FLAGS = ${BANG_CNCC_HOST_COMPILER_FLAGS}")
foreach(flag ${CMAKE_HOST_FLAGS} ${CMAKE_HOST_FLAGS_${build_configuration}})
# Extra quotes are added around each flag to help cncc parse out flags with spaces.
set(cncc_host_compiler_flags ${cncc_host_compiler_flags} ${flag})
endforeach()
# message("cncc_host_compiler_flags = ${cncc_host_compiler_flags}")
# Add the build specific configuration flags
list(APPEND BANG_CNCC_FLAGS ${BANG_CNCC_FLAGS_${build_configuration}})
# Remove the duplicated flags and including
list(REMOVE_DUPLICATES BANG_CNCC_FLAGS)
list(REMOVE_DUPLICATES BANG_CNCC_INCLUDE_ARGS)
# bang_execute_process - Executes a command with optional command echo and status message.
#
# status - Status message to print if verbose is true
# command - COMMAND argument from the usual execute_process argument structure
# ARGN - Remaining arguments are the command with arguments
#
# BANG_result - return value from running the command
#
# Make this a macro instead of a function, so that things like RESULT_VARIABLE
# and other return variables are present after executing the process.
macro(bang_execute_process status command)
set(_command ${command})
if(NOT "x${_command}" STREQUAL "xCOMMAND")
message(FATAL_ERROR "Malformed call to bang_execute_process. Missing COMMAND as second argument. (command = ${command})")
endif()
if(verbose)
execute_process(COMMAND "${CMAKE_COMMAND}" -E echo -- ${status})
# Now we need to build up our command string. We are accounting for quotes
# and spaces, anything else is left up to the user to fix if they want to
# copy and paste a runnable command line.
set(bang_execute_process_string)
foreach(arg ${ARGN})
# If there are quotes, excape them, so they come through.
string(REPLACE "\"" "\\\"" arg ${arg})
# Args with spaces need quotes around them to get them to be parsed as a single argument.
if(arg MATCHES " ")
list(APPEND bang_execute_process_string "\"${arg}\"")
else()
list(APPEND bang_execute_process_string ${arg})
endif()
endforeach()
# Echo the command
execute_process(COMMAND ${CMAKE_COMMAND} -E echo ${bang_execute_process_string})
endif()
# Run the command
execute_process(COMMAND ${ARGN} RESULT_VARIABLE BANG_result )
endmacro()
# Delete the target file
bang_execute_process(
"Removing ${generated_file}"
COMMAND "${CMAKE_COMMAND}" -E remove "${generated_file}"
)
# cncc ignore host flags
set(cncc_host_compiler_flags "")
# Generate the code
bang_execute_process(
"Generating ${generated_file}"
COMMAND "${BANG_CNCC_EXECUTABLE}"
"${source_file}"
${bang_language_flag}
${format_flag} -o "${generated_file}"
${cncc_flags}
${cncc_host_compiler_flags}
${BANG_CNCC_FLAGS}
-DCNCC
${BANG_CNCC_INCLUDE_ARGS}
)
if(BANG_result)
# Since cncc can sometimes leave half done files make sure that we delete the output file.
bang_execute_process(
"Removing ${generated_file}"
COMMAND "${CMAKE_COMMAND}" -E remove "${generated_file}"
)
message(FATAL_ERROR "Error generating file ${generated_file}")
else()
message(VERBOSE "Generated ${generated_file} successfully.")
endif()
# Cnbin resource report commands.
if( build_cnbin )
# Run with -cnbin to produce resource usage report.
bang_execute_process(
"Generating ${generated_cnbin_file}"
COMMAND "${BANG_CNCC_EXECUTABLE}"
"${source_file}"
${BANG_CNCC_FLAGS}
${cncc_flags}
${cncc_host_compiler_flags}
-DCNCC
-cnbin
-o "${generated_cnbin_file}"
${BANG_CNCC_INCLUDE_ARGS}
)
# Execute the parser script.
bang_execute_process(
"Executing the parser script"
COMMAND "${CMAKE_COMMAND}"
-D "input_file:STRING=${generated_cnbin_file}"
-P "${BANG_parse_cnbin}"
)
endif()
/**
* \file include/megcore_cambricon.h
*
* This file is part of MegDNN, a deep neural network run-time library
* developed by Megvii.
*
* \copyright Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*/
#pragma once
#include "megcore.h"
#include <cndev.h>
#include <cnml.h>
#include <cnrt.h>
#include "megdnn/internal/visibility_prologue.h"
namespace megcore {
megcoreStatus_t createDeviceHandleWithGlobalInitStatus(
megcoreDeviceHandle_t* devHandle, int deviceID, unsigned int flags,
bool global_initialized);
struct CambriconContext {
cnrtQueue_t queue = nullptr;
CambriconContext() = default;
CambriconContext(cnrtQueue_t q) : queue{q} {}
};
megcoreStatus_t createComputingHandleWithCambriconContext(
megcoreComputingHandle_t* compHandle, megcoreDeviceHandle_t devHandle,
unsigned int flags, const CambriconContext& ctx);
megcoreStatus_t getCambriconContext(
megcoreComputingHandle_t handle, CambriconContext* ctx);
} // namespace megcore
static inline megcoreStatus_t megcoreCreateComputingHandleWithCNRTQueue(
megcoreComputingHandle_t* compHandle, megcoreDeviceHandle_t devHandle,
unsigned int flags, cnrtQueue_t queue) {
megcore::CambriconContext ctx{queue};
return megcore::createComputingHandleWithCambriconContext(
compHandle, devHandle, flags, ctx);
}
static inline megcoreStatus_t megcoreGetCNRTQueue(
megcoreComputingHandle_t handle, cnrtQueue_t* queue) {
megcore::CambriconContext ctx;
auto ret = megcore::getCambriconContext(handle, &ctx);
*queue = ctx.queue;
return ret;
}
#include "megdnn/internal/visibility_epilogue.h"
// vim: syntax=cpp.doxygen
load("//brain/megbrain/dnn:flags.bzl", "megdnn_opts")
load("@megvii3//tools/build_rules:bangc.bzl", "bangc_library")
package(default_visibility = ["//brain/megbrain/dnn:__subpackages__"])
bangc_library(
name = "bangc_kernels",
srcs = glob([
"**/*.mlu",
]) + [
"//brain/megbrain/dnn:src/common/utils.cuh",
],
hdrs = glob([
"**/*.mlu.h",
]),
deps = [
"//brain/megbrain/dnn:public_headers",
"//brain/megbrain/sdk/build_config",
],
copts = megdnn_opts + [
"-Ibrain/megbrain/dnn",
],
)
filegroup(
name = "cambricon_backend_files",
srcs = glob([
"**/*.cpp",
"**/*.h",
"**/*.hpp",
]),
)
/**
* \file dnn/src/cambricon/checksum/checksum.mlu.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#pragma once
#include "src/cambricon/utils.mlu.h"
#ifdef __cplusplus
extern "C" {
#endif
void checksum_kernel_union1(uint32_t* dst, const uint32_t* src, int num_elems);
void checksum_kernel_union4(uint32_t* dst, const uint32_t* src, int num_elems);
#ifdef __cplusplus
}
#endif
// vim: ft=cpp syntax=cpp.doxygen
/**
* \file dnn/src/cambricon/checksum/checksum_kernel_union1.mlu
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#include "checksum.mlu.h"
#include "cnsccl.h"
#include "mlu.h"
#define CLUSTER_DIM 1
#define CORE_DIM 4
#define STRIDE 1024
__mlu_entry__ void checksum_kernel_union1(uint32_t* dst, uint32_t* src,
int nr_elems) {
__nram__ uint32_t sum = 0;
__nram__ uint32_t val[STRIDE];
const uint32_t TASK_DIM = CLUSTER_DIM * CORE_DIM;
__mlu_shared__ uint32_t partial_sum[TASK_DIM];
int task_stride = STRIDE;
int start_offset = taskId * task_stride;
int global_stride = taskDim * task_stride;
for (int task_offset = start_offset; task_offset < nr_elems;
task_offset += global_stride) {
int end_offset = task_offset + task_stride;
end_offset = end_offset > nr_elems ? nr_elems : end_offset;
int copy_elems = end_offset - task_offset;
__memcpy(val, src + task_offset, copy_elems * sizeof(uint32_t),
GDRAM2NRAM);
for (int i = 0; i < copy_elems; i++) {
sum = sum + val[i] * (task_offset + i + 1);
}
}
partial_sum[taskId] = sum;
__sync_cluster();
if (taskId == 0) {
uint32_t res = 0;
for (int i = 0; i < taskDim; i++) {
res += partial_sum[i];
}
dst[0] = res;
}
}
#undef CLUSTER_DIM
#undef CORE_DIM
#undef STRIDE
// vim: ft=cpp syntax=cpp.doxygen
/**
* \file dnn/src/cambricon/checksum/checksum_kernel_union4.mlu
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#include "checksum.mlu.h"
#include "cnsccl.h"
#include "mlu.h"
#define CLUSTER_DIM 4
#define CORE_DIM 4
#define STRIDE 1024
__mlu_entry__ void checksum_kernel_union4(uint32_t* dst, uint32_t* src,
int nr_elems) {
__nram__ uint32_t sum = 0;
__nram__ uint32_t val[STRIDE];
__mlu_shared__ uint32_t partial_sum_send[CORE_DIM];
__mlu_shared__ uint32_t partial_sum_recv[CLUSTER_DIM];
int task_stride = STRIDE;
int start_offset = taskId * task_stride;
int global_stride = taskDim * task_stride;
for (int task_offset = start_offset; task_offset < nr_elems;
task_offset += global_stride) {
int end_offset = task_offset + task_stride;
end_offset = end_offset > nr_elems ? nr_elems : end_offset;
int copy_elems = end_offset - task_offset;
__memcpy(val, src + task_offset, copy_elems * sizeof(uint32_t),
GDRAM2NRAM);
for (int i = 0; i < copy_elems; i++) {
sum = sum + val[i] * (task_offset + i + 1);
}
}
partial_sum_send[coreId] = sum;
__sync_cluster();
if (coreId == 0) {
for (int i = 1; i < CORE_DIM; ++i) {
partial_sum_send[0] += partial_sum_send[i];
}
}
__sync_all();
cnscclGather((void*)&partial_sum_send, (void*)&partial_sum_recv, 1,
cnscclInt, 0);
__sync_all();
if (clusterId == 0 && coreId == 0) {
uint32_t res = 0;
for (int i = 0; i < CLUSTER_DIM; ++i) {
res += partial_sum_recv[i];
}
dst[0] = res;
}
}
#undef CLUSTER_DIM
#undef CORE_DIM
#undef STRIDE
// vim: ft=cpp syntax=cpp.doxygen
/**
* \file dnn/src/cambricon/checksum/opr_impl.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#include "src/cambricon/checksum/opr_impl.h"
#include "src/cambricon/checksum/checksum.mlu.h"
#include "src/cambricon/utils.h"
#include <algorithm>
using namespace megdnn;
using namespace cambricon;
namespace {
void bang_c_wrapper(
uint32_t* dst, const uint32_t* src, int nr_elems, cnrtQueue_t queue,
cnrtCoreVersion_t core_version) {
cnrtKernelParamsBuffer_t params;
cnrt_check(cnrtGetKernelParamsBuffer(&params));
cnrt_check(cnrtKernelParamsBufferAddParam(params, &dst, sizeof(uint32_t*)));
cnrt_check(cnrtKernelParamsBufferAddParam(params, &src, sizeof(uint32_t*)));
cnrt_check(cnrtKernelParamsBufferAddParam(params, &nr_elems, sizeof(int)));
if (core_version == CNRT_MLU270) {
cnrtDim3_t dim;
dim.x = 16;
dim.y = 1;
dim.z = 1;
cnrtFunctionType_t c = CNRT_FUNC_TYPE_UNION4;
cnrt_check(cnrtInvokeKernel_V2(
(void*)&checksum_kernel_union4, dim, params, c, queue));
} else if (core_version == CNRT_MLU220) {
cnrtDim3_t dim;
dim.x = 4;
dim.y = 1;
dim.z = 1;
cnrtFunctionType_t c = CNRT_FUNC_TYPE_UNION1;
cnrt_check(cnrtInvokeKernel_V2(
(void*)&checksum_kernel_union1, dim, params, c, queue));
}
after_kernel_launch();
cnrt_check(cnrtDestroyKernelParamsBuffer(params));
}
} // namespace
size_t ChecksumForwardImpl::get_workspace_in_bytes(const TensorLayout& /* data */) {
size_t ws_size = sizeof(ChecksumForward::Result::checksum);
return ws_size;
}
ChecksumForward::Result ChecksumForwardImpl::exec(
_megdnn_tensor_in data, _megdnn_workspace workspace) {
Result result;
memset(&result, 0, sizeof(result));
check_exec(data.layout, workspace.size);
auto queue = cnrt_queue(handle());
auto ptr = static_cast<uint8_t*>(data.raw_ptr());
size_t size_all = data.layout.shape[0], size_ints = size_all / sizeof(uint32_t);
auto last_val_size = std::min<size_t>(size_all, 4);
cnrt_check(cnrtMemcpyAsync(
&result.last_val, ptr + size_all - last_val_size, last_val_size, queue,
CNRT_MEM_TRANS_DIR_DEV2HOST));
if (size_ints) {
auto&& device_info = current_device_info();
bang_c_wrapper(
reinterpret_cast<uint32_t*>(workspace.raw_ptr),
static_cast<uint32_t*>(data.raw_ptr()), size_ints, queue,
device_info.core_version);
cnrt_check(cnrtMemcpyAsync(
&result.checksum, workspace.raw_ptr, sizeof(result.checksum), queue,
CNRT_MEM_TRANS_DIR_DEV2HOST));
}
cnrt_check(cnrtSyncQueue(queue));
return result;
}
// vim: syntax=cpp.doxygen
/**
* \file dnn/src/cambricon/checksum/opr_impl.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#pragma once
#include "megdnn/oprs.h"
#include "src/cambricon/utils.h"
namespace megdnn {
namespace cambricon {
class ChecksumForwardImpl final : public ChecksumForward {
public:
using ChecksumForward::ChecksumForward;
size_t get_workspace_in_bytes(const TensorLayout&) override;
bool is_thread_safe() const override { return true; }
Result exec(_megdnn_tensor_in data, _megdnn_workspace workspace) override;
};
} // namespace cambricon
} // namespace megdnn
// vim: syntax=cpp.doxygen
/**
* \file dnn/src/cambricon/handle.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#include "src/common/handle_impl.h"
#include "src/common/version_symbol.h"
#include "src/cambricon/handle.h"
#include "src/cambricon/utils.h"
#include <cnrt.h>
#include "src/cambricon/checksum/opr_impl.h"
namespace megdnn {
namespace cambricon {
HandleImpl::HandleImpl(megcoreComputingHandle_t comp_handle)
: HandleImplHelper(comp_handle, HandleType::CAMBRICON) {
// Get megcore device handle
megcoreDeviceHandle_t dev_handle;
megcoreGetDeviceHandle(comp_handle, &dev_handle);
int dev_id;
megcoreGetDeviceID(dev_handle, &dev_id);
unsigned int dev_num;
cnrt_check(cnrtGetDeviceCount(&dev_num));
MEGDNN_MARK_USED_VAR(dev_num);
// check validity of device_id
megdnn_assert(dev_id >= 0 && static_cast<unsigned int>(dev_id) < dev_num);
m_device_id = dev_id;
cnrt_check(cnrtGetDeviceInfo(&m_device_info, dev_id));
megcore::getCambriconContext(comp_handle, &m_megcore_context);
}
HandleImpl::~HandleImpl() noexcept = default;
template <typename Opr>
std::unique_ptr<Opr> HandleImpl::create_operator() {
megdnn_throw("unsupported cambricon opr");
return nullptr;
}
size_t HandleImpl::alignment_requirement() const {
return 1;
}
MEGDNN_SPECIALIZE_CREATE_OPERATOR(ChecksumForward);
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wpragmas"
#pragma GCC diagnostic ignored "-Winstantiation-after-specialization"
MEGDNN_FOREACH_OPR_CLASS(MEGDNN_INST_CREATE_OPERATOR)
#pragma GCC diagnostic pop
} // namespace cambricon
} // namespace megdnn
MEGDNN_VERSION_SYMBOL3(
CNRT, CNRT_MAJOR_VERSION, CNRT_MINOR_VERSION, CNRT_PATCH_VERSION);
// vim: syntax=cpp.doxygen
/**
* \file dnn/src/cambricon/handle.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#pragma once
#include "megcore_cambricon.h"
#include "megdnn/basic_types.h"
#include "megdnn/handle.h"
#include "megdnn/oprs/general.h"
#include "src/common/handle_impl.h"
#include "src/common/utils.h"
#include <atomic>
#include <mutex>
#include <cnrt.h>
namespace megdnn {
namespace cambricon {
class HandleImpl : public HandleImplHelper {
public:
HandleImpl(megcoreComputingHandle_t computing_handle);
~HandleImpl() noexcept;
size_t alignment_requirement() const override;
const cnrtDeviceInfo_t& device_info() const { return m_device_info; }
template <typename Opr>
std::unique_ptr<Opr> create_operator();
const megcore::CambriconContext& megcore_context() const {
return m_megcore_context;
}
int device_id() const { return m_device_id; }
cnrtQueue_t queue() const { return megcore_context().queue; }
//! global matmul opr
Checksum* checksum_opr() override final {
return get_helper_opr<Checksum, 0>(this);
}
private:
int m_device_id;
//! MegDNN handle does not manage the lifetime of cnrt queue.
megcore::CambriconContext m_megcore_context;
cnrtDeviceInfo_t m_device_info;
};
} // namespace cambricon
} // namespace megdnn
// vim: syntax=cpp.doxygen
/**
* \file dnn/src/cambricon/megcore/cambricon_computing_context.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#include "megcore.h"
#include "src/cambricon/utils.h"
#include "src/common/utils.h"
#include "src/cambricon/megcore/cambricon_computing_context.hpp"
using namespace megcore;
using namespace megcore::cambricon;
CambriconComputingContext::CambriconComputingContext(
megcoreDeviceHandle_t dev_handle, unsigned int flags,
const CambriconContext& ctx)
: ComputingContext(dev_handle, flags),
own_queue{ctx.queue == nullptr},
context_{ctx} {
megcorePlatform_t platform;
megcoreGetPlatform(dev_handle, &platform);
megdnn_assert(platform == megcorePlatformCambricon);
if (own_queue) {
cnrt_check(cnrtCreateQueue(&context_.queue));
}
}
CambriconComputingContext::~CambriconComputingContext() {
if (own_queue) {
cnrt_check(cnrtDestroyQueue(context_.queue));
}
}
void CambriconComputingContext::memcpy(
void* dst, const void* src, size_t size_in_bytes, megcoreMemcpyKind_t kind) {
cnrtMemTransDir_t dir;
switch (kind) {
case megcoreMemcpyDeviceToHost:
dir = CNRT_MEM_TRANS_DIR_DEV2HOST;
break;
case megcoreMemcpyHostToDevice:
dir = CNRT_MEM_TRANS_DIR_HOST2DEV;
break;
case megcoreMemcpyDeviceToDevice:
dir = CNRT_MEM_TRANS_DIR_DEV2DEV;
break;
default:
megdnn_throw("bad cnrt mem trans dir");
}
if (kind == megcoreMemcpyDeviceToDevice) {
cnrt_check(cnrtSyncQueue(context_.queue));
cnrt_check(cnrtMemcpy(dst, const_cast<void*>(src), size_in_bytes, dir));
return;
}
cnrt_check(cnrtMemcpyAsync(
dst, const_cast<void*>(src), size_in_bytes, context_.queue, dir));
}
void CambriconComputingContext::memset(void* dst, int value, size_t size_in_bytes) {
cnrt_check(cnrtSyncQueue(context_.queue));
cnrt_check(cnrtMemset(dst, value, size_in_bytes));
}
void CambriconComputingContext::synchronize() {
cnrt_check(cnrtSyncQueue(context_.queue));
}
// vim: syntax=cpp.doxygen
/**
* \file dnn/src/cambricon/megcore/cambricon_computing_context.hpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#pragma once
#include "megcore_cambricon.h"
#include "src/common/megcore/common/computing_context.hpp"
namespace megcore {
namespace cambricon {
class CambriconComputingContext final : public ComputingContext {
public:
CambriconComputingContext(megcoreDeviceHandle_t dev_handle,
unsigned int flags,
const CambriconContext& ctx = {});
~CambriconComputingContext();
void memcpy(void* dst, const void* src, size_t size_in_bytes,
megcoreMemcpyKind_t kind) override;
void memset(void* dst, int value, size_t size_in_bytes) override;
void synchronize() override;
const CambriconContext& context() const { return context_; }
cnrtQueue_t queue() const { return context().queue; }
private:
bool own_queue;
CambriconContext context_;
};
} // namespace cambricon
} // namespace megcore
// vim: syntax=cpp.doxygen
/**
* \file dnn/src/cambricon/megcore/cambricon_device_context.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#include "megcore.h"
#include "src/cambricon/utils.h"
#include "src/common/utils.h"
#include "src/cambricon/megcore/cambricon_device_context.hpp"
#define STR_HELPER(x) #x
#define STR(x) STR_HELPER(x)
#define CNRT_VERSION_STR \
STR(CNRT_MAJOR_VERSION) \
"." STR(CNRT_MINOR_VERSION) "." STR(CNRT_PATCH_VERSION)
#pragma message "compile with cnrt " CNRT_VERSION_STR " "
#undef STR_HELPER
#undef STR
using namespace megcore;
using namespace cambricon;
CambriconDeviceContext::CambriconDeviceContext(
int device_id, unsigned int flags, bool global_initialized)
: DeviceContext(megcorePlatformCambricon, device_id, flags) {
if (!global_initialized)
init_status.init();
unsigned int version;
cnrt_check(cnrtGetVersion(&version));
megdnn_assert(
version == CNRT_VERSION, "megcore compiled with cnrt %d, get %d at runtime",
CNRT_VERSION, version);
unsigned int dev_num;
cnrt_check(cnrtGetDeviceCount(&dev_num));
MEGDNN_MARK_USED_VAR(dev_num);
// check validity of device_id
megdnn_assert(device_id >= 0 && static_cast<unsigned int>(device_id) < dev_num);
cnrt_check(cnrtGetDeviceInfo(&device_info, device_id));
}
CambriconDeviceContext::~CambriconDeviceContext() noexcept = default;
size_t CambriconDeviceContext::mem_alignment_in_bytes() const noexcept {
return 1;
}
void CambriconDeviceContext::activate() {
int id = device_id();
cnrtDev_t dev;
cnrt_check(cnrtGetDeviceHandle(&dev, id));
cnrt_check(cnrtSetCurrentDevice(dev));
}
void* CambriconDeviceContext::malloc(size_t size_in_bytes) {
void* ptr;
cnrt_check(cnrtMalloc(&ptr, size_in_bytes));
return ptr;
}
void CambriconDeviceContext::free(void* ptr) {
cnrt_check(cnrtFree(ptr));
}
CambriconDeviceContext::InitStatus CambriconDeviceContext::init_status;
// vim: syntax=cpp.doxygen
/**
* \file dnn/src/cambricon/megcore/cambricon_device_context.hpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#pragma once
#include <mutex>
#include "megcore_cambricon.h"
#include "src/common/megcore/common/device_context.hpp"
#include "src/common/utils.h"
namespace megcore {
namespace cambricon {
class CambriconDeviceContext : public DeviceContext {
public:
CambriconDeviceContext(int device_id, unsigned int flags,
bool global_initialized = false);
~CambriconDeviceContext() noexcept;
size_t mem_alignment_in_bytes() const noexcept override;
void activate() override;
void* malloc(size_t size_in_bytes) override;
void free(void* ptr) override;
struct InitStatus {
bool initialized;
std::mutex mtx;
InitStatus() : initialized{false} {}
void init() {
std::lock_guard<std::mutex> guard{mtx};
if (!initialized) {
auto cnrt_err = cnrtInit(0);
initialized = cnrt_err == CNRT_RET_SUCCESS;
megdnn_assert(initialized, "cnrt initialize failed: (cnrt:%d)",
static_cast<int>(cnrt_err));
}
}
~InitStatus() {
if (initialized) {
cnrtDestroy();
initialized = false;
}
}
};
static InitStatus init_status;
private:
cnrtDeviceInfo_t device_info;
};
} // namespace cambricon
} // namespace megcore
// vim: syntax=cpp.doxygen
/**
* \file dnn/src/cambricon/megcore/public_api/computing.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#include "megcore_cambricon.h"
#include "src/cambricon/megcore/cambricon_computing_context.hpp"
#include "src/cambricon/megcore/cambricon_device_context.hpp"
#include "src/common/megcore/public_api/computing.hpp"
#include "src/common/megcore/public_api/device.hpp"
#include "src/common/utils.h"
using namespace megcore;
megcoreStatus_t megcore::createDeviceHandleWithGlobalInitStatus(
megcoreDeviceHandle_t* devHandle, int deviceID, unsigned int flags,
bool global_initialized) {
auto content = megdnn::make_unique<cambricon::CambriconDeviceContext>(
deviceID, flags, global_initialized);
auto& ctx = *devHandle;
ctx = new megcoreDeviceContext;
ctx->content = std::move(content);
return megcoreSuccess;
}
megcoreStatus_t megcore::createComputingHandleWithCambriconContext(
megcoreComputingHandle_t* compHandle, megcoreDeviceHandle_t devHandle,
unsigned int flags, const CambriconContext& ctx) {
auto content = megdnn::make_unique<cambricon::CambriconComputingContext>(
devHandle, flags, ctx);
auto& H = *compHandle;
H = new megcoreComputingContext;
H->content = std::move(content);
return megcoreSuccess;
}
megcoreStatus_t megcore::getCambriconContext(
megcoreComputingHandle_t handle, CambriconContext* ctx) {
auto&& H = handle;
megdnn_assert(H);
megcoreDeviceHandle_t dev_handle = H->content->dev_handle();
megcorePlatform_t platform;
megcoreGetPlatform(dev_handle, &platform);
megdnn_assert(platform == megcorePlatformCambricon);
auto context = static_cast<megcore::cambricon::CambriconComputingContext*>(
H->content.get());
*ctx = context->context();
return megcoreSuccess;
}
// vim: syntax=cpp.doxygen
/**
* \file dnn/src/cambricon/utils.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#include "src/cambricon/utils.h"
#include "src/cambricon/utils.mlu.h"
#include "src/cambricon/handle.h"
#include "src/common/utils.h"
#include <mutex>
#include <unordered_map>
using namespace megdnn;
using namespace cambricon;
namespace {
struct DeviceInfoRecord {
bool init = false;
cnrtDeviceInfo_t device_info;
std::mutex mtx;
};
std::unordered_map<cnrtDev_t, int> dev2device_id;
std::mutex dev2device_id_mtx;
constexpr int MAX_NR_DEVICE = 64;
DeviceInfoRecord device_info_rec[MAX_NR_DEVICE];
} // namespace
void cambricon::__throw_cnrt_error__(cnrtRet_t err, const char* msg) {
auto s = ssprintf(
"cnrt return %s(%d) occurred; expr: %s", cnrtGetErrorStr(err), int(err),
msg);
megdnn_throw(s.c_str());
}
cnrtDeviceInfo_t cambricon::current_device_info() {
static bool dev2device_id_init = false;
{
std::lock_guard<std::mutex> lock(dev2device_id_mtx);
if (!dev2device_id_init) {
unsigned int dev_num = 0;
cnrt_check(cnrtGetDeviceCount(&dev_num));
for (unsigned int dev_id = 0; dev_id < dev_num; ++dev_id) {
cnrtDev_t dev;
cnrt_check(cnrtGetDeviceHandle(&dev, dev_id));
dev2device_id[dev] = dev_id;
}
dev2device_id_init = true;
}
}
cnrtDev_t dev;
cnrt_check(cnrtGetCurrentDevice(&dev));
{
std::lock_guard<std::mutex> lock(dev2device_id_mtx);
int dev_id = dev2device_id.at(dev);
auto& rec = device_info_rec[dev_id];
{
std::lock_guard<std::mutex> lock(rec.mtx);
if (!rec.init) {
cnrt_check(cnrtGetDeviceInfo(&rec.device_info, dev_id));
rec.init = true;
}
}
return rec.device_info;
}
}
// vim: syntax=cpp.doxygen
/**
* \file dnn/src/cambricon/utils.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#pragma once
#include "megcore_cdefs.h"
#include "megdnn/handle.h"
#include "src/cambricon/utils.mlu.h"
#include "src/common/utils.h"
#include "src/cambricon/handle.h"
#include <cnrt.h>
namespace megdnn {
namespace cambricon {
static inline HandleImpl* concrete_handle(Handle* handle) {
return static_cast<cambricon::HandleImpl*>(handle);
}
static inline cnrtQueue_t cnrt_queue(Handle* handle) {
return concrete_handle(handle)->queue();
}
//! get device info of current active device
cnrtDeviceInfo_t current_device_info();
} // namespace cambricon
} // namespace megdnn
// vim: syntax=cpp.doxygen
/**
* \file dnn/src/cambricon/utils.mlu.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#pragma once
#include "src/common/utils.cuh"
#include <stdint.h>
#include <cnrt.h>
#define cnrt_check(_x) \
do { \
cnrtRet_t _ret = (_x); \
if (_ret != CNRT_RET_SUCCESS) { \
::megdnn::cambricon::__throw_cnrt_error__(_ret, #_x); \
} \
} while (0)
#define after_kernel_launch() \
do { \
cnrt_check(cnrtGetLastErr()); \
} while (0)
namespace megdnn {
namespace cambricon {
//! Error handling funcions
MEGDNN_NORETURN void __throw_cnrt_error__(cnrtRet_t err, const char* msg);
} // namespace cambricon
} // namespace megdnn
// vim: syntax=cpp.doxygen
/**
* \file dnn/test/cambricon/checksum.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#include "megdnn/oprs.h"
#include "test/cambricon/fixture.h"
#include "test/common/checker.h"
using namespace megdnn;
using namespace test;
TEST_F(CAMBRICON, CHECKSUM_FORWARD) {
auto cambricon_opr = handle_cambricon()->create_operator<megdnn::Checksum>(),
naive_opr = handle_naive()->create_operator<megdnn::Checksum>();
std::mt19937 rng(std::random_device{}());
for (size_t size : {3, 8, 4 * 4 * 1024, 12345, 1024 * 1024, 1024 * 1024 * 10}) {
auto aligned_size = size + ((512 - size % 512) % 512);
auto run = [&](megdnn::Checksum* opr, void* ptr, bool log_size) {
TensorND tensor;
tensor.reset_ptr(ptr);
tensor.layout.init_contiguous_stride({size});
tensor.layout.dtype = dtype::Byte();
WorkspaceWrapper workspace(
handle_cambricon(), opr->get_workspace_in_bytes(tensor.layout));
if (log_size) {
printf("checksum(%zu): workspace=%zu\n", size,
workspace.workspace().size);
}
return opr->exec(tensor, workspace.workspace());
};
std::vector<uint8_t> buf(aligned_size);
for (size_t i = 0; i < size; ++i)
buf[i] = 1;
auto run_offsset = [&](size_t offset) {
void* dev_ptr = megdnn_malloc(handle_cambricon(), buf.size() + offset);
void* dev_buf = static_cast<char*>(dev_ptr) + offset;
Checksum::Result res_cambricon[2], res_naive[2];
for (int change_last = 0; change_last < 2; ++change_last) {
if (change_last)
++buf[size - 1];
megdnn_memcpy_H2D(handle_cambricon(), dev_buf, buf.data(), size);
res_cambricon[change_last] =
run(cambricon_opr.get(), dev_buf, !change_last);
res_naive[change_last] = run(naive_opr.get(), buf.data(), false);
}
megdnn_free(handle_cambricon(), dev_ptr);
ASSERT_EQ(res_naive[0], res_cambricon[0]) << "failed for size " << size;
ASSERT_EQ(res_naive[1], res_cambricon[1]);
ASSERT_NE(res_cambricon[0], res_cambricon[1]);
};
for (size_t i = 0; i < 8; ++i) {
run_offsset(i);
}
}
}
// vim: syntax=cpp.doxygen
/**
* \file dnn/test/cambricon/fixture.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#include "test/cambricon/fixture.h"
#include "src/cambricon/handle.h"
#include "src/cambricon/utils.h"
#include "test/common/memory_manager.h"
#include "test/common/random_state.h"
#include "test/common/utils.h"
#include <cnrt.h>
#include <cstdlib>
using namespace megdnn;
using namespace test;
void CAMBRICON::SetUp() {
RandomState::reset();
megcoreDeviceHandle_t dev_handle;
// use card 0
megcore_check(megcoreCreateDeviceHandle(&dev_handle, megcorePlatformCambricon, 0));
megcoreComputingHandle_t comp_handle;
megcore_check(megcoreCreateComputingHandle(&comp_handle, dev_handle));
m_handle_cambricon = Handle::make(comp_handle);
megdnn_assert(m_handle_cambricon);
}
Handle* CAMBRICON::handle_naive() {
if (!m_handle_naive)
m_handle_naive = create_cpu_handle(2);
return m_handle_naive.get();
}
void CAMBRICON::TearDown() {
m_handle_naive.reset();
m_handle_cambricon.reset();
MemoryManagerHolder::instance()->clear();
}
// vim: syntax=cpp.doxygen
/**
* \file dnn/test/cambricon/fixture.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#pragma once
#include <gtest/gtest.h>
#include "test/common/fix_gtest_on_platforms_without_exception.inl"
#include "megcore_cdefs.h"
#include "megdnn/handle.h"
#include <memory>
namespace megdnn {
namespace test {
class CAMBRICON : public ::testing::Test {
public:
void SetUp() override;
void TearDown() override;
Handle* handle_cambricon() { return m_handle_cambricon.get(); }
Handle* handle_naive();
private:
std::unique_ptr<Handle> m_handle_naive;
std::unique_ptr<Handle> m_handle_cambricon;
};
} // namespace test
} // namespace megdnn
// vim: syntax=cpp.doxygen
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册