提交 00ef6772 编写于 作者: M Megvii Engine Team

fix(mgb): remove internal for cambricon and atlas

GitOrigin-RevId: 861e349eb44b87a12ff12d7d1e7ac5331a6e73ef
上级 aeffcd58
......@@ -38,7 +38,9 @@ option(MGE_CUDA_USE_STATIC "Enable MegEngine CUDA static linking." ON)
# Build options (cached booleans); configure with -D<NAME>=ON/OFF.
option(MGE_WITH_TRT "Build MegEngine with TensorRT." ON)
option(MGE_USE_SYSTEM_LIB "Build MegEngine with system libraries." OFF)
option(MGB_WITH_FLATBUFFERS "Build MegBrain with FlatBuffers serialization support." ON)
# Vendor accelerator backends, both disabled by default.
option(MGE_WITH_CAMBRICON "Build MegEngine with Cambricon support" OFF)
option(BUILD_SHARED_LIBS "Build shared libraries" ON)
option(MGE_WITH_ATLAS "Build MegEngine with Atlas support" OFF)
option(MGE_ENABLE_RTTI "Build with RTTI" ON)
option(MGE_ENABLE_LOGGING "Build with logging" ON)
option(MGE_DEBUG_UTIL "Enable debug utility" ON)
......@@ -406,6 +408,51 @@ if(MGE_WITH_CUDA)
set(MGE_CUDA_LIBS "${MGE_CUDA_LIBS}")
endif()
# Cambricon (MLU) support: headers and libraries come from $NEUWARE_HOME and
# device code is compiled by cncc through the FindBANG helper module.
if(MGE_WITH_CAMBRICON)
include_directories("$ENV{NEUWARE_HOME}/include")
link_directories("$ENV{NEUWARE_HOME}/lib64")
include(cmake/FindBANG/FindBANG.cmake)
# Translate the user-facing MGE_MLU_ARCH name into the numeric BANG arch.
if (${MGE_MLU_ARCH} STREQUAL "MLU100")
set(BANG_ARCH "100")
elseif (${MGE_MLU_ARCH} STREQUAL "MLU1h8")
set(BANG_ARCH "110")
elseif (${MGE_MLU_ARCH} STREQUAL "MLU220")
set(BANG_ARCH "220")
elseif (${MGE_MLU_ARCH} STREQUAL "MLU270")
set(BANG_ARCH "270")
elseif (${MGE_MLU_ARCH} STREQUAL "MLU290")
set(BANG_ARCH "290")
elseif (${MGE_MLU_ARCH} STREQUAL "MLU200")
set(BANG_ARCH "200")
else()
message (FATAL_ERROR "Unsupported MLU arch.")
endif()
set(BANG_CNCC_FLAGS "${BANG_CNCC_FLAGS} --bang-mlu-arch=${MGE_MLU_ARCH}")
set(BANG_CNCC_FLAGS "${BANG_CNCC_FLAGS} -std=c++11 -Werror")
# Host code also sees the device arch via the __BANG_ARCH__ macro.
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -D__BANG_ARCH__=${BANG_ARCH}")
# Mirror the host build type in the cncc optimization/debug flags.
if (${CMAKE_BUILD_TYPE} STREQUAL "Debug")
set(BANG_CNCC_FLAGS "${BANG_CNCC_FLAGS} -g -O0")
elseif (${CMAKE_BUILD_TYPE} STREQUAL "Release")
set(BANG_CNCC_FLAGS "${BANG_CNCC_FLAGS} -O3")
elseif (${CMAKE_BUILD_TYPE} STREQUAL "RelWithDebInfo")
set(BANG_CNCC_FLAGS "${BANG_CNCC_FLAGS} -g -O3")
elseif (${CMAKE_BUILD_TYPE} STREQUAL "MinSizeRel")
set(BANG_CNCC_FLAGS "${BANG_CNCC_FLAGS} -Os")
endif()
# These helper scripts define the imported targets libcnrt/libcndev/libcnml.
include(cmake/cnrt.cmake)
include(cmake/cndev.cmake)
include(cmake/cnml.cmake)
list(APPEND MGE_CAMBRICON_LIBS libcnrt libcndev libcnml)
set(MGE_CAMBRICON_LIBS "${MGE_CAMBRICON_LIBS}")
endif()

# Atlas (ACL) support: imported target libascendcl defined in cmake/aclrt.cmake.
if(MGE_WITH_ATLAS)
include(cmake/aclrt.cmake)
list(APPEND MGE_ATLAS_LIBS libascendcl)
set(MGE_ATLAS_LIBS "${MGE_ATLAS_LIBS}")
set(MGB_ATLAS ${MGE_WITH_ATLAS})
endif()
find_program(CCACHE_BIN ccache)
if(CCACHE_BIN)
......@@ -494,6 +541,11 @@ set(MGB_CUDA ${MGE_WITH_CUDA})
set(MEGDNN_WITH_CUDA ${MGE_WITH_CUDA})
# CAMBRICON
# Propagate the user-facing option into the MGB/MegDNN feature variables.
set(MGB_CAMBRICON ${MGE_WITH_CAMBRICON})
set(MEGDNN_WITH_CAMBRICON ${MGE_WITH_CAMBRICON})
# Debug info
if(${CMAKE_BUILD_TYPE} STREQUAL "Debug" OR ${CMAKE_BUILD_TYPE} STREQUAL "RelWithDebInfo")
set(MGB_ASSERT_LOC 1)
......
# Locate the Huawei Atlas ACL runtime library (libascendcl) and expose it as
# the IMPORTED target `libascendcl` carrying its include directory.
# Inputs (environment): ACLRT_HOME, LD_LIBRARY_PATH, LIBRARY_PATH.
# Fails the configure step with FATAL_ERROR when the library or headers are
# missing.

# NOTE: `if($ENV{LIBRARY_PATH})` expands the value and is a CMake error when
# the variable is unset; test for definedness instead.
if(DEFINED ENV{LIBRARY_PATH})
  string(REPLACE ":" ";" SYSTEM_LIBRARY_PATHS "$ENV{LIBRARY_PATH}")
endif()

find_library(ACLRT_LIBRARY
  NAMES libascendcl.so
  PATHS $ENV{LD_LIBRARY_PATH} "$ENV{ACLRT_HOME}/lib64/stub" ${CMAKE_INSTALL_PREFIX}
  HINTS ${SYSTEM_LIBRARY_PATHS}
  PATH_SUFFIXES stub
  DOC "ACL library.")
if(ACLRT_LIBRARY STREQUAL "ACLRT_LIBRARY-NOTFOUND")
  message(FATAL_ERROR "Can not find ACLRT Library")
endif()

# The library normally lives at <root>/lib64/stub/, so the SDK root is three
# levels up from it.
get_filename_component(__found_aclrt_root "${ACLRT_LIBRARY}/../../../" REALPATH)

find_path(ACLRT_INCLUDE_DIR
  NAMES acl/acl.h
  HINTS "$ENV{ACLRT_HOME}/include" ${__found_aclrt_root}
  PATH_SUFFIXES include
  DOC "Path to ACLRT include directory.")
if(ACLRT_INCLUDE_DIR STREQUAL "ACLRT_INCLUDE_DIR-NOTFOUND")
  # Report the component that is actually missing (was: "... ACLRT Library").
  message(FATAL_ERROR "Can not find ACLRT include directory")
endif()

add_library(libascendcl SHARED IMPORTED)
set_target_properties(libascendcl PROPERTIES
  IMPORTED_LOCATION "${ACLRT_LIBRARY}"
  INTERFACE_INCLUDE_DIRECTORIES "${ACLRT_INCLUDE_DIR}")
message(STATUS "Found ACLRT: ${__found_aclrt_root}")
# Locate the Cambricon device-management library (libcndev), parse its version
# from cndev.h, and expose it as the IMPORTED target `libcndev`.
# Inputs (environment): NEUWARE_HOME, LD_LIBRARY_PATH, LIBRARY_PATH.

# NOTE: `if($ENV{LIBRARY_PATH})` expands the value and is a CMake error when
# the variable is unset; test for definedness instead.
if(DEFINED ENV{LIBRARY_PATH})
  string(REPLACE ":" ";" SYSTEM_LIBRARY_PATHS "$ENV{LIBRARY_PATH}")
endif()

find_library(CNDEV_LIBRARY
  NAMES libcndev.so
  PATHS $ENV{LD_LIBRARY_PATH} "$ENV{NEUWARE_HOME}/lib64" ${CMAKE_INSTALL_PREFIX}
  HINTS ${SYSTEM_LIBRARY_PATHS}
  PATH_SUFFIXES lib lib64
  DOC "CNDEV library.")
if(CNDEV_LIBRARY STREQUAL "CNDEV_LIBRARY-NOTFOUND")
  message(FATAL_ERROR "Can not find CNDEV Library")
endif()

# Guess the include dir next to the library (<root>/lib64/../include).
get_filename_component(__found_cndev_root "${CNDEV_LIBRARY}/../include" REALPATH)

find_path(CNDEV_INCLUDE_DIR
  NAMES cndev.h
  HINTS "$ENV{NEUWARE_HOME}/include" ${__found_cndev_root}
  PATH_SUFFIXES include
  DOC "Path to CNDEV include directory.")
if(CNDEV_INCLUDE_DIR STREQUAL "CNDEV_INCLUDE_DIR-NOTFOUND")
  # Report the component that is actually missing (was: "... CNDEV Library").
  message(FATAL_ERROR "Can not find CNDEV include directory")
endif()

# CNDEV encodes its version in five numeric CNDEV_VERSION_<n> macros.
foreach(idx RANGE 1 5)
  file(STRINGS "${CNDEV_INCLUDE_DIR}/cndev.h" CNDEV_${idx}
       REGEX "^#define CNDEV_VERSION_${idx} [0-9]+.*$")
  string(REGEX REPLACE "^#define CNDEV_VERSION_${idx} ([0-9]+).*$" "\\1"
         CNDEV_VERSION_${idx} "${CNDEV_${idx}}")
endforeach()
set(CNDEV_VERSION_STRING
    "${CNDEV_VERSION_1}.${CNDEV_VERSION_2}.${CNDEV_VERSION_3}.${CNDEV_VERSION_4}.${CNDEV_VERSION_5}")

add_library(libcndev SHARED IMPORTED)
set_target_properties(libcndev PROPERTIES
  IMPORTED_LOCATION "${CNDEV_LIBRARY}"
  INTERFACE_INCLUDE_DIRECTORIES "${CNDEV_INCLUDE_DIR}")
message(STATUS "Found CNDEV: ${__found_cndev_root} (found version: ${CNDEV_VERSION_STRING})")
# Locate the Cambricon Machine Learning library (libcnml), parse its version
# from cnml.h, and expose it as the IMPORTED target `libcnml`.
# Inputs (environment): NEUWARE_HOME, LD_LIBRARY_PATH, LIBRARY_PATH.

# NOTE: `if($ENV{LIBRARY_PATH})` expands the value and is a CMake error when
# the variable is unset; test for definedness instead.
if(DEFINED ENV{LIBRARY_PATH})
  string(REPLACE ":" ";" SYSTEM_LIBRARY_PATHS "$ENV{LIBRARY_PATH}")
endif()

find_library(CNML_LIBRARY
  NAMES libcnml.so
  PATHS $ENV{LD_LIBRARY_PATH} "$ENV{NEUWARE_HOME}/lib64" ${CMAKE_INSTALL_PREFIX}
  HINTS ${SYSTEM_LIBRARY_PATHS}
  PATH_SUFFIXES lib lib64
  DOC "CNML library.")
if(CNML_LIBRARY STREQUAL "CNML_LIBRARY-NOTFOUND")
  message(FATAL_ERROR "Can not find CNML Library")
endif()

# Guess the include dir next to the library (<root>/lib64/../include).
get_filename_component(__found_cnml_root "${CNML_LIBRARY}/../include" REALPATH)

find_path(CNML_INCLUDE_DIR
  NAMES cnml.h
  HINTS "$ENV{NEUWARE_HOME}/include" ${__found_cnml_root}
  PATH_SUFFIXES include
  DOC "Path to CNML include directory.")
if(CNML_INCLUDE_DIR STREQUAL "CNML_INCLUDE_DIR-NOTFOUND")
  # Report the component that is actually missing (was: "... CNML Library").
  message(FATAL_ERROR "Can not find CNML include directory")
endif()

# Version macros: CNML_{MAJOR,MINOR,PATCH}_VERSION.
foreach(part MAJOR MINOR PATCH)
  file(STRINGS "${CNML_INCLUDE_DIR}/cnml.h" CNML_${part}
       REGEX "^#define CNML_${part}_VERSION [0-9]+.*$")
  string(REGEX REPLACE "^#define CNML_${part}_VERSION ([0-9]+).*$" "\\1"
         CNML_VERSION_${part} "${CNML_${part}}")
endforeach()
set(CNML_VERSION_STRING
    "${CNML_VERSION_MAJOR}.${CNML_VERSION_MINOR}.${CNML_VERSION_PATCH}")

add_library(libcnml SHARED IMPORTED)
set_target_properties(libcnml PROPERTIES
  IMPORTED_LOCATION "${CNML_LIBRARY}"
  INTERFACE_INCLUDE_DIRECTORIES "${CNML_INCLUDE_DIR}")
message(STATUS "Found CNML: ${__found_cnml_root} (found version: ${CNML_VERSION_STRING})")
# Locate the Cambricon runtime library (libcnrt), parse its version from
# cnrt.h, and expose it as the IMPORTED target `libcnrt`.
# Inputs (environment): NEUWARE_HOME, LD_LIBRARY_PATH, LIBRARY_PATH.

# NOTE: `if($ENV{LIBRARY_PATH})` expands the value and is a CMake error when
# the variable is unset; test for definedness instead.
if(DEFINED ENV{LIBRARY_PATH})
  string(REPLACE ":" ";" SYSTEM_LIBRARY_PATHS "$ENV{LIBRARY_PATH}")
endif()

find_library(CNRT_LIBRARY
  NAMES libcnrt.so
  PATHS $ENV{LD_LIBRARY_PATH} "$ENV{NEUWARE_HOME}/lib64" ${CMAKE_INSTALL_PREFIX}
  HINTS ${SYSTEM_LIBRARY_PATHS}
  PATH_SUFFIXES lib lib64
  DOC "CNRT library.")
if(CNRT_LIBRARY STREQUAL "CNRT_LIBRARY-NOTFOUND")
  message(FATAL_ERROR "Can not find CNRT Library")
endif()

# Guess the include dir next to the library (<root>/lib64/../include).
get_filename_component(__found_cnrt_root "${CNRT_LIBRARY}/../include" REALPATH)

find_path(CNRT_INCLUDE_DIR
  NAMES cnrt.h
  HINTS "$ENV{NEUWARE_HOME}/include" ${__found_cnrt_root}
  PATH_SUFFIXES include
  DOC "Path to CNRT include directory.")
if(CNRT_INCLUDE_DIR STREQUAL "CNRT_INCLUDE_DIR-NOTFOUND")
  # Report the component that is actually missing (was: "... CNRT Library").
  message(FATAL_ERROR "Can not find CNRT include directory")
endif()

# Version macros: CNRT_{MAJOR,MINOR,PATCH}_VERSION.
foreach(part MAJOR MINOR PATCH)
  file(STRINGS "${CNRT_INCLUDE_DIR}/cnrt.h" CNRT_${part}
       REGEX "^#define CNRT_${part}_VERSION [0-9]+.*$")
  string(REGEX REPLACE "^#define CNRT_${part}_VERSION ([0-9]+).*$" "\\1"
         CNRT_VERSION_${part} "${CNRT_${part}}")
endforeach()
set(CNRT_VERSION_STRING
    "${CNRT_VERSION_MAJOR}.${CNRT_VERSION_MINOR}.${CNRT_VERSION_PATCH}")

add_library(libcnrt SHARED IMPORTED)
set_target_properties(libcnrt PROPERTIES
  IMPORTED_LOCATION "${CNRT_LIBRARY}"
  INTERFACE_INCLUDE_DIRECTORIES "${CNRT_INCLUDE_DIR}")
message(STATUS "Found CNRT: ${__found_cnrt_root} (found version: ${CNRT_VERSION_STRING})")
/**
* \file dnn/include/megcore_atlas.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#pragma once
#include "megcore.h"
#include <acl/acl.h>
#include "megdnn/internal/visibility_prologue.h"
namespace megcore {
//! Create an Atlas device handle; `global_initialized` tells the device
//! context whether ACL was already initialized by the caller (so the
//! process-wide init is skipped).
megcoreStatus_t createAtlasDeviceHandleWithGlobalInitStatus(
megcoreDeviceHandle_t* devHandle, int deviceID, unsigned int flags,
bool global_initialized);

//! Light-weight wrapper of the ACL stream used by a computing handle.
struct AtlasContext {
aclrtStream stream = nullptr;
AtlasContext() = default;
AtlasContext(aclrtStream s) : stream{s} {}
};

//! Create a computing handle bound to an existing AtlasContext.
megcoreStatus_t createComputingHandleWithAtlasContext(
megcoreComputingHandle_t* compHandle, megcoreDeviceHandle_t devHandle,
unsigned int flags, const AtlasContext& ctx);

//! Retrieve the AtlasContext stored in `handle`.
megcoreStatus_t getAtlasContext(megcoreComputingHandle_t handle,
AtlasContext* ctx);

namespace atlas {
//! convert acl error code to error string
const char* get_error_str(aclError error);
} // namespace atlas
} // namespace megcore
//! C-style convenience wrapper: wrap \p stream in an AtlasContext and
//! delegate to megcore::createComputingHandleWithAtlasContext().
inline megcoreStatus_t megcoreCreateComputingHandleWithACLStream(
        megcoreComputingHandle_t* compHandle, megcoreDeviceHandle_t devHandle,
        unsigned int flags, aclrtStream stream) {
    const megcore::AtlasContext atlas_ctx{stream};
    return megcore::createComputingHandleWithAtlasContext(compHandle, devHandle,
                                                          flags, atlas_ctx);
}
//! C-style convenience wrapper: fetch the AtlasContext of \p handle and
//! expose its ACL stream through \p stream.
inline megcoreStatus_t megcoreGetACLStream(megcoreComputingHandle_t handle,
                                           aclrtStream* stream) {
    megcore::AtlasContext atlas_ctx;
    const auto status = megcore::getAtlasContext(handle, &atlas_ctx);
    *stream = atlas_ctx.stream;
    return status;
}
#include "megdnn/internal/visibility_epilogue.h"
// vim: syntax=cpp.doxygen
/**
* \file include/megcore_cambricon.h
*
* This file is part of MegDNN, a deep neural network run-time library
* developed by Megvii.
*
* \copyright Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*/
#pragma once
#include "megcore.h"
#include <cndev.h>
#include <cnml.h>
#include <cnrt.h>
#include "megdnn/internal/visibility_prologue.h"
namespace megcore {
//! Create a Cambricon device handle; `global_initialized` indicates the
//! runtime was already initialized by the caller (mirrors the Atlas
//! variant — confirm exact semantics in the implementation).
megcoreStatus_t createDeviceHandleWithGlobalInitStatus(
megcoreDeviceHandle_t* devHandle, int deviceID, unsigned int flags,
bool global_initialized);

//! Light-weight wrapper of the CNRT queue used by a computing handle.
struct CambriconContext {
cnrtQueue_t queue = nullptr;
CambriconContext() = default;
CambriconContext(cnrtQueue_t q) : queue{q} {}
};

//! Create a computing handle bound to an existing CambriconContext.
megcoreStatus_t createComputingHandleWithCambriconContext(
megcoreComputingHandle_t* compHandle, megcoreDeviceHandle_t devHandle,
unsigned int flags, const CambriconContext& ctx);

//! Retrieve the CambriconContext stored in `handle`.
megcoreStatus_t getCambriconContext(megcoreComputingHandle_t handle,
CambriconContext* ctx);
} // namespace megcore
//! C-style convenience wrapper: wrap \p queue in a CambriconContext and
//! delegate to megcore::createComputingHandleWithCambriconContext().
static inline megcoreStatus_t megcoreCreateComputingHandleWithCNRTQueue(
        megcoreComputingHandle_t* compHandle, megcoreDeviceHandle_t devHandle,
        unsigned int flags, cnrtQueue_t queue) {
    const megcore::CambriconContext cambricon_ctx{queue};
    return megcore::createComputingHandleWithCambriconContext(
            compHandle, devHandle, flags, cambricon_ctx);
}
//! C-style convenience wrapper: fetch the CambriconContext of \p handle and
//! expose its CNRT queue through \p queue.
static inline megcoreStatus_t megcoreGetCNRTQueue(
        megcoreComputingHandle_t handle, cnrtQueue_t* queue) {
    megcore::CambriconContext cambricon_ctx;
    const auto status = megcore::getCambriconContext(handle, &cambricon_ctx);
    *queue = cambricon_ctx.queue;
    return status;
}
#include "megdnn/internal/visibility_epilogue.h"
// vim: syntax=cpp.doxygen
......@@ -19,6 +19,8 @@
//! Platform tag of a megcore device/computing handle. Values are explicit
//! and non-contiguous — presumably stable identifiers, so do not renumber
//! existing entries.
typedef enum {
megcorePlatformCPU = 1,
megcorePlatformCUDA = 4,
megcorePlatformCambricon = 7,
megcorePlatformAtlas = 8,
} megcorePlatform_t;
/**
......
......@@ -33,6 +33,8 @@ class Handle {
ARMV7 = 4,
AARCH64 = 5,
CUDA = 6,
ATLAS = 13,
CAMBRICON = 12,
};
protected:
......
......@@ -45,6 +45,24 @@ if(MGE_WITH_CUDA)
list(APPEND SOURCES ${CUSOURCES})
endif()
# Cambricon backend sources: host .cpp files plus .mlu device kernels that
# bang_compile() (from FindBANG) turns into linkable objects.
if(MGE_WITH_CAMBRICON)
file(GLOB_RECURSE SOURCES_ cambricon/*.cpp)
list(APPEND SOURCES ${SOURCES_})
file(GLOB_RECURSE BANG_SOURCES cambricon/*.mlu)
list(APPEND MEGDNN_INCLUDES "${PROJECT_SOURCE_DIR}/dnn/include")
list(APPEND MEGDNN_INCLUDES "${PROJECT_SOURCE_DIR}/dnn")
list(APPEND MEGDNN_INCLUDES "${PROJECT_BINARY_DIR}/genfiles")
bang_compile(BANG_OBJS "${BANG_SOURCES}" "${MEGDNN_INCLUDES}")
list(APPEND SOURCES ${BANG_OBJS})
endif()

# Atlas backend sources; MEGDNN_WITH_ATLAS is added to the megdnn defines.
if(MGE_WITH_ATLAS)
file(GLOB_RECURSE SOURCES_ atlas/*.cpp)
list(APPEND SOURCES ${SOURCES_})
list(APPEND LIBMEGDNN_DEF -DMEGDNN_WITH_ATLAS=1)
endif()
add_definitions(${LIBMEGDNN_DEF})
......@@ -97,8 +115,21 @@ else()
target_link_libraries(megdnn PRIVATE ${MGE_BLAS_LIBS})
endif()
# Link the Atlas runtime. With shared libs the dependency is wrapped in
# BUILD_INTERFACE so exported/installed targets do not carry it.
if(MGE_WITH_ATLAS)
if (BUILD_SHARED_LIBS)
target_link_libraries(megdnn PRIVATE $<BUILD_INTERFACE:${MGE_ATLAS_LIBS}>)
else()
target_link_libraries(megdnn PRIVATE ${MGE_ATLAS_LIBS})
endif()
endif()

if(CMAKE_THREAD_LIBS_INIT)
target_link_libraries(megdnn PRIVATE Threads::Threads)
endif()

# Cambricon: link both the compiled .mlu objects and the runtime libraries.
if(MGE_WITH_CAMBRICON)
target_link_libraries(megdnn PRIVATE ${BANG_OBJS} ${MGE_CAMBRICON_LIBS})
endif()
install(TARGETS megdnn EXPORT ${MGE_EXPORT_TARGETS})
/**
* \file dnn/src/atlas/checksum/opr_impl.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#include "src/atlas/checksum/opr_impl.h"
#include "src/atlas/utils.h"
#include "src/naive/handle.h"
#include "src/common/utils.h"
#include "src/common/opr_delegate.h"
#include <cstring>
using namespace megdnn;
using namespace atlas;
//! No device workspace is required: exec() copies the tensor to host memory
//! and computes the checksum there (see exec() below in this file).
size_t ChecksumForwardImpl::get_workspace_in_bytes(const TensorLayout&) {
return 0;
}
//! Compute the checksum of `data` by copying it to the host and delegating
//! to the in-process CPU implementation.
ChecksumForward::Result ChecksumForwardImpl::exec(_megdnn_tensor_in data,
_megdnn_workspace workspace) {
check_exec(data.layout, workspace.size);
//! FIXME currently the cce programming interface is not so stable, here i
//! just allocate some memory of cpu here and compute the result in cpu
std::vector<uint8_t> cpu_data(data.layout.span().dist_byte(), 0);
megcoreDeviceHandle_t dev_handle;
megcoreComputingHandle_t comp_handle = handle()->megcore_computing_handle();
// NOTE(review): dev_handle is fetched but never used afterwards.
megcoreGetDeviceHandle(comp_handle, &dev_handle);
// Device-to-host copy followed by a full synchronize so cpu_data is valid.
megcoreMemcpy(comp_handle, cpu_data.data(), data.raw_ptr, cpu_data.size(),
megcoreMemcpyDeviceToHost);
megcoreSynchronize(comp_handle);
// Delegate the actual checksum to the CPU operator, giving it its own
// host-side workspace.
auto opr = inplace_cpu_handle()->create_operator<ChecksumForward>();
size_t workspace_size = opr->get_workspace_in_bytes(data.layout);
std::vector<uint8_t> cpu_workspace_data(workspace_size, 0);
Workspace cpu_workspace(
reinterpret_cast<dt_byte*>(cpu_workspace_data.data()),
cpu_workspace_data.size());
return opr->exec(TensorND{cpu_data.data(), data.layout}, cpu_workspace);
}
// vim: syntax=cpp.doxygen
/**
* \file dnn/src/atlas/checksum/opr_impl.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#pragma once
#include "megdnn/oprs.h"
namespace megdnn {
namespace atlas {
//! Atlas checksum operator: device memory is copied to the host and the
//! checksum is delegated to the CPU implementation (see opr_impl.cpp).
class ChecksumForwardImpl final : public ChecksumForward {
public:
using ChecksumForward::ChecksumForward;
//! exec() works on call-local host buffers, so concurrent calls are safe
bool is_thread_safe() const override { return true; }
size_t get_workspace_in_bytes(const TensorLayout& data) override;
Result exec(_megdnn_tensor_in data, _megdnn_workspace workspace) override;
};
} // namespace atlas
} // namespace megdnn
// vim: syntax=cpp.doxygen
/**
* \file dnn/src/atlas/handle.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "megcore_atlas.h"
#include "src/common/handle_impl.h"
#include "src/atlas/handle.h"
#include "src/atlas/checksum/opr_impl.h"
#include <acl/acl.h>
namespace megdnn {
namespace atlas {
//! Build an Atlas handle on top of an existing megcore computing handle,
//! caching the device id and the AtlasContext (ACL stream) it carries.
HandleImpl::HandleImpl(megcoreComputingHandle_t comp_handle)
: HandleImplHelper(comp_handle, HandleType::ATLAS) {
// Get megcore device handle
megcoreDeviceHandle_t dev_handle;
megcoreGetDeviceHandle(comp_handle, &dev_handle);
int dev_id;
megcoreGetDeviceID(dev_handle, &dev_id);
m_device_id = dev_id;
// Cache the AtlasContext owned by the computing handle.
megcore::getAtlasContext(comp_handle, &m_megcore_context);
}

HandleImpl::~HandleImpl() noexcept = default;

//! Fallback for operators with no Atlas implementation: always throws.
//! Supported operators are wired up via MEGDNN_SPECIALIZE_CREATE_OPERATOR.
template <typename Opr>
std::unique_ptr<Opr> HandleImpl::create_operator() {
megdnn_throw("unsupported atlas opr");
return nullptr; // unreachable; keeps compilers without noreturn info happy
}
//! Minimum address alignment for buffers used with this handle.
size_t HandleImpl::alignment_requirement() const {
//! because memcpyasync api requires that the memory is 128bytes alignment
// NOTE(review): the comment above says 128-byte alignment but 64 is
// returned — confirm which value the ACL async-memcpy API requires.
return 64;
}
// ChecksumForward is the only operator with a real Atlas implementation.
MEGDNN_SPECIALIZE_CREATE_OPERATOR(ChecksumForward);
// Instantiate the generic (throwing) create_operator for every other opr;
// the pragmas silence instantiation-after-specialization warnings.
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wpragmas"
#pragma GCC diagnostic ignored "-Winstantiation-after-specialization"
MEGDNN_FOREACH_OPR_CLASS(MEGDNN_INST_CREATE_OPERATOR)
#pragma GCC diagnostic pop
} // namespace atlas
} // namespace megdnn
// vim: syntax=cpp.doxygen
/**
* \file dnn/src/atlas/handle.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#pragma once
#include "megcore_atlas.h"
#include "megdnn/basic_types.h"
#include "megdnn/handle.h"
#include "megdnn/oprs/general.h"
#include "src/common/handle_impl.h"
#include "src/common/megcore/common/device_context.hpp"
#include "src/common/utils.h"
#include "src/atlas/megcore/device_context.hpp"
#include <atomic>
#include <mutex>
#include "acl/acl_rt.h"
namespace megdnn {
namespace atlas {
//! MegDNN handle for Huawei Atlas devices; wraps the megcore computing
//! handle and exposes its ACL stream and device id.
class HandleImpl : public HandleImplHelper {
public:
HandleImpl(megcoreComputingHandle_t computing_handle);
~HandleImpl() noexcept;

//! minimum address alignment of buffers usable with this handle
size_t alignment_requirement() const override;

//! create an operator instance; throws for oprs without an Atlas impl
template <typename Opr>
std::unique_ptr<Opr> create_operator();

const megcore::AtlasContext& megcore_context() const {
return m_megcore_context;
}

int device_id() const { return m_device_id; }

//! ACL stream of the underlying computing handle
aclrtStream stream() const { return megcore_context().stream; }

//! global checksum opr
Checksum* checksum_opr() override final {
return get_helper_opr<Checksum, 0>(this);
}

private:
int m_device_id;
//! MegDNN handle does not manage the lifetime of the ACL stream.
megcore::AtlasContext m_megcore_context;
};
} // namespace atlas
} // namespace megdnn
// vim: syntax=cpp.doxygen
/**
* \file dnn/src/atlas/megcore/atlas_computing_context.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "megcore.h"
#include "src/atlas//megcore/computing_context.hpp"
#include "src/atlas/utils.h"
#include "src/common/utils.h"
using namespace megcore;
using namespace megcore::atlas;
//! Build a computing context for an Atlas device handle. A null stream in
//! `ctx` means this context creates its own ACL stream and owns it (it is
//! destroyed in the destructor).
AtlasComputingContext::AtlasComputingContext(megcoreDeviceHandle_t dev_handle,
unsigned int flags,
const AtlasContext& ctx)
: ComputingContext(dev_handle, flags),
m_own_stream{ctx.stream == nullptr},
m_ctx{ctx} {
// The device handle must belong to the Atlas platform.
megcorePlatform_t platform;
megcoreGetPlatform(dev_handle, &platform);
megdnn_assert(platform == megcorePlatformAtlas);
if (m_own_stream) {
acl_check(aclrtCreateStream(&m_ctx.stream));
}
}
//! Destroy the ACL stream if (and only if) this context created it.
// NOTE(review): acl_check throws on failure; throwing from a destructor
// terminates the process when already unwinding — confirm this is intended.
AtlasComputingContext::~AtlasComputingContext() {
if (m_own_stream) {
acl_check(aclrtDestroyStream(m_ctx.stream));
}
}
//! Enqueue an asynchronous copy on the context's stream; the megcore copy
//! kind is mapped to the corresponding ACL kind. The copy is only enqueued
//! here — callers must synchronize before reading host results.
void AtlasComputingContext::memcpy(void* dst, const void* src,
size_t size_in_bytes,
megcoreMemcpyKind_t kind) {
aclrtMemcpyKind atlas_kind;
switch (kind) {
case megcoreMemcpyDeviceToHost:
atlas_kind = ACL_MEMCPY_DEVICE_TO_HOST;
break;
case megcoreMemcpyHostToDevice:
atlas_kind = ACL_MEMCPY_HOST_TO_DEVICE;
break;
case megcoreMemcpyDeviceToDevice:
atlas_kind = ACL_MEMCPY_DEVICE_TO_DEVICE;
break;
default:
megdnn_throw("bad atlas memcpy kind");
}
// destination capacity and copy count are both size_in_bytes
acl_check(aclrtMemcpyAsync(dst, size_in_bytes, src, size_in_bytes,
atlas_kind, m_ctx.stream));
}
//! Blocking memset: the stream is drained first so the memset cannot race
//! with previously enqueued asynchronous work.
void AtlasComputingContext::memset(void* dst, int value, size_t size_in_bytes) {
acl_check(aclrtSynchronizeStream(m_ctx.stream));
acl_check(aclrtMemset(dst, size_in_bytes, value, size_in_bytes));
}

//! Block until all work enqueued on the context's stream has finished.
void AtlasComputingContext::synchronize() {
acl_check(aclrtSynchronizeStream(m_ctx.stream));
}
// vim: syntax=cpp.doxygen
/**
* \file dnn/src/atlas/megcore/computing_context.hpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#pragma once
#include "megcore_atlas.h"
#include "src/common/megcore/common/computing_context.hpp"
#include <acl/acl_rt.h>
namespace megcore {
namespace atlas {
//! ComputingContext backed by an ACL stream. When the supplied AtlasContext
//! has a null stream, this context creates its own stream and destroys it
//! in the destructor (tracked by m_own_stream).
class AtlasComputingContext final : public ComputingContext {
public:
AtlasComputingContext(megcoreDeviceHandle_t dev_handle, unsigned int flags,
const AtlasContext& ctx = {});
~AtlasComputingContext();

//! enqueue an async copy on the stream (kind mapped to ACL)
void memcpy(void* dst, const void* src, size_t size_in_bytes,
megcoreMemcpyKind_t kind) override;
//! synchronize the stream, then perform a blocking memset
void memset(void* dst, int value, size_t size_in_bytes) override;
//! wait for all work enqueued on the stream
void synchronize() override;

const AtlasContext& context() const { return m_ctx; }
aclrtStream stream() const { return context().stream; }

private:
bool m_own_stream; //!< true when the stream was created by this context
AtlasContext m_ctx;
};
} // namespace atlas
} // namespace megcore
// vim: syntax=cpp.doxygen
/**
* \file dnn/src/atlas/megcore/device_context.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "src/atlas/megcore/device_context.hpp"
#include "megcore.h"
#include "src/atlas/utils.h"
#include "src/common/utils.h"
#include "acl/acl.h"
using namespace megcore;
using namespace atlas;
//! Build a device context for an Atlas device. Unless the caller already
//! initialized ACL globally, the process-wide one-shot aclInit() guard runs.
AtlasDeviceContext::AtlasDeviceContext(int device_id, unsigned int flags,
bool global_initialized)
: DeviceContext(megcorePlatformAtlas, device_id, flags) {
if (!global_initialized)
init_status.init();
// A negative device id means "current device": query it from ACL.
// NOTE(review): the queried id is not stored anywhere — confirm intended.
int id = device_id;
if (id < 0) {
acl_check(aclrtGetDevice(&id));
}
}
AtlasDeviceContext::~AtlasDeviceContext() noexcept = default;

//! minimum alignment of buffers allocated by this context
size_t AtlasDeviceContext::mem_alignment_in_bytes() const noexcept {
return 64;
}

//! Bind the device to the calling thread; negative ids ("current device")
//! are left untouched.
void AtlasDeviceContext::activate() {
int id = device_id();
if (id >= 0) {
acl_check(aclrtSetDevice(id));
}
}

//! Release this thread's binding/resources for the device (requires a
//! concrete, non-negative device id).
void AtlasDeviceContext::deactivate() {
int id = device_id();
megdnn_assert(id >= 0);
acl_check(aclrtResetDevice(id));
}

//! Allocate device memory, preferring huge pages (ACL_MEM_MALLOC_HUGE_FIRST).
void* AtlasDeviceContext::malloc(size_t size_in_bytes) {
void* ptr;
acl_check(aclrtMalloc(&ptr, size_in_bytes, ACL_MEM_MALLOC_HUGE_FIRST));
return ptr;
}

void AtlasDeviceContext::free(void* ptr) {
acl_check(aclrtFree(ptr));
}

// Definition of the process-wide one-shot ACL init guard.
AtlasDeviceContext::InitStatus AtlasDeviceContext::init_status;
// vim: syntax=cpp.doxygen
/**
* \file dnn/src/atlas/megcore/device_context.hpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#pragma once
#include "src/common/megcore/common/device_context.hpp"
#include "src/common/utils.h"
#include "megcore_atlas.h"
#include <mutex>
#include "acl/acl.h"
namespace megcore {
namespace atlas {
//! DeviceContext for Huawei Atlas devices backed by the ACL runtime.
class AtlasDeviceContext : public DeviceContext {
public:
//! `global_initialized` skips the process-wide aclInit() when the caller
//! has already initialized ACL itself.
AtlasDeviceContext(int device_id, unsigned int flags,
bool global_initialized = false);
~AtlasDeviceContext() noexcept;

size_t mem_alignment_in_bytes() const noexcept override;
void activate() override;
void deactivate() override;
void* malloc(size_t size_in_bytes) override;
void free(void* ptr) override;

//! Mutex-guarded one-shot wrapper around aclInit(); asserts on failure.
struct InitStatus {
bool initialized;
std::mutex mtx;
InitStatus() : initialized{false} {}
void init() {
std::lock_guard<std::mutex> guard{mtx};
if (!initialized) {
auto err = aclInit(nullptr);
initialized = err == ACL_ERROR_NONE;
megdnn_assert(initialized,
"aclrt initialize failed: (acl:%d): %s",
static_cast<int>(err),
megcore::atlas::get_error_str(err));
}
}
// NOTE(review): aclFinalize() is never called here; only the flag is
// reset — confirm finalization is handled elsewhere (or intentionally
// skipped).
~InitStatus() {
if (initialized) {
initialized = false;
}
}
};
//! process-wide init guard shared by all Atlas device contexts
static InitStatus init_status;
};
} // namespace atlas
} // namespace megcore
// vim: syntax=cpp.doxygen
/**
* \file dnn/src/atlas/megcore/public_api/computing.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "megcore_atlas.h"
#include "src/atlas/megcore/computing_context.hpp"
#include "src/atlas/megcore/device_context.hpp"
#include "src/common/megcore/public_api/computing.hpp"
#include "src/common/megcore/public_api/device.hpp"
#include "src/common/utils.h"
using namespace megcore;
//! Public factory: build an AtlasDeviceContext and hand it to a newly
//! allocated megcoreDeviceContext stored in *devHandle. Always returns
//! megcoreSuccess; errors inside the context surface as exceptions/asserts.
megcoreStatus_t megcore::createAtlasDeviceHandleWithGlobalInitStatus(
megcoreDeviceHandle_t* devHandle, int deviceID, unsigned int flags,
bool global_initialized) {
auto content = megdnn::make_unique<atlas::AtlasDeviceContext>(
deviceID, flags, global_initialized);
auto& ctx = *devHandle;
ctx = new megcoreDeviceContext;
ctx->content = std::move(content);
return megcoreSuccess;
}
//! Public factory: build an AtlasComputingContext bound to `ctx` and store
//! it in a newly allocated megcoreComputingContext. Only flags == 0 is
//! supported (asserted below).
megcoreStatus_t megcore::createComputingHandleWithAtlasContext(
megcoreComputingHandle_t* compHandle, megcoreDeviceHandle_t devHandle,
unsigned int flags, const AtlasContext& ctx) {
MEGDNN_MARK_USED_VAR(flags);
megdnn_assert(flags == 0);
auto content = megdnn::make_unique<atlas::AtlasComputingContext>(
devHandle, flags, ctx);
auto& H = *compHandle;
H = new megcoreComputingContext;
H->content = std::move(content);
return megcoreSuccess;
}
//! Extract the AtlasContext from a computing handle. Asserts that the
//! handle is non-null and belongs to the Atlas platform before
//! downcasting its content.
megcoreStatus_t megcore::getAtlasContext(megcoreComputingHandle_t handle,
AtlasContext* ctx) {
auto&& H = handle;
megdnn_assert(H);
megcoreDeviceHandle_t dev_handle = H->content->dev_handle();
megcorePlatform_t platform;
megcoreGetPlatform(dev_handle, &platform);
megdnn_assert(platform == megcorePlatformAtlas);
auto context = static_cast<megcore::atlas::AtlasComputingContext*>(
H->content.get());
*ctx = context->context();
return megcoreSuccess;
}
//! Map an aclError code to its symbolic name; unknown codes yield
//! "unknown error". The returned pointer refers to a string literal and
//! stays valid for the lifetime of the program.
const char* megcore::atlas::get_error_str(aclError error) {
#define ERROR(_err) \
case _err: \
return #_err;
switch (error) {
ERROR(ACL_ERROR_NONE);
ERROR(ACL_ERROR_INVALID_PARAM);
ERROR(ACL_ERROR_UNINITIALIZE);
ERROR(ACL_ERROR_REPEAT_INITIALIZE);
ERROR(ACL_ERROR_INVALID_FILE);
ERROR(ACL_ERROR_WRITE_FILE);
ERROR(ACL_ERROR_INVALID_FILE_SIZE);
ERROR(ACL_ERROR_PARSE_FILE);
ERROR(ACL_ERROR_FILE_MISSING_ATTR);
ERROR(ACL_ERROR_FILE_ATTR_INVALID);
ERROR(ACL_ERROR_INVALID_DUMP_CONFIG);
ERROR(ACL_ERROR_INVALID_PROFILING_CONFIG);
ERROR(ACL_ERROR_INVALID_MODEL_ID);
ERROR(ACL_ERROR_DESERIALIZE_MODEL);
ERROR(ACL_ERROR_PARSE_MODEL);
ERROR(ACL_ERROR_READ_MODEL_FAILURE);
ERROR(ACL_ERROR_MODEL_SIZE_INVALID);
ERROR(ACL_ERROR_MODEL_MISSING_ATTR);
ERROR(ACL_ERROR_MODEL_INPUT_NOT_MATCH);
ERROR(ACL_ERROR_MODEL_OUTPUT_NOT_MATCH);
ERROR(ACL_ERROR_MODEL_NOT_DYNAMIC);
ERROR(ACL_ERROR_OP_TYPE_NOT_MATCH);
ERROR(ACL_ERROR_OP_INPUT_NOT_MATCH);
ERROR(ACL_ERROR_OP_OUTPUT_NOT_MATCH);
ERROR(ACL_ERROR_OP_ATTR_NOT_MATCH);
ERROR(ACL_ERROR_OP_NOT_FOUND);
ERROR(ACL_ERROR_OP_LOAD_FAILED);
ERROR(ACL_ERROR_UNSUPPORTED_DATA_TYPE);
ERROR(ACL_ERROR_FORMAT_NOT_MATCH);
ERROR(ACL_ERROR_BIN_SELECTOR_NOT_REGISTERED);
ERROR(ACL_ERROR_KERNEL_NOT_FOUND);
ERROR(ACL_ERROR_BIN_SELECTOR_ALREADY_REGISTERED);
ERROR(ACL_ERROR_KERNEL_ALREADY_REGISTERED);
ERROR(ACL_ERROR_INVALID_QUEUE_ID);
ERROR(ACL_ERROR_REPEAT_SUBSCRIBE);
ERROR(ACL_ERROR_STREAM_NOT_SUBSCRIBE);
ERROR(ACL_ERROR_THREAD_NOT_SUBSCRIBE);
ERROR(ACL_ERROR_WAIT_CALLBACK_TIMEOUT);
ERROR(ACL_ERROR_REPEAT_FINALIZE);
ERROR(ACL_ERROR_NOT_STATIC_AIPP);
ERROR(ACL_ERROR_BAD_ALLOC);
ERROR(ACL_ERROR_API_NOT_SUPPORT);
ERROR(ACL_ERROR_INVALID_DEVICE);
ERROR(ACL_ERROR_MEMORY_ADDRESS_UNALIGNED);
ERROR(ACL_ERROR_RESOURCE_NOT_MATCH);
ERROR(ACL_ERROR_INVALID_RESOURCE_HANDLE);
ERROR(ACL_ERROR_FEATURE_UNSUPPORTED);
ERROR(ACL_ERROR_STORAGE_OVER_LIMIT);
ERROR(ACL_ERROR_INTERNAL_ERROR);
ERROR(ACL_ERROR_FAILURE);
ERROR(ACL_ERROR_GE_FAILURE);
ERROR(ACL_ERROR_RT_FAILURE);
ERROR(ACL_ERROR_DRV_FAILURE);
ERROR(ACL_ERROR_PROFILING_FAILURE);
default:
return "unknown error";
}
#undef ERROR
}
// vim: syntax=cpp.doxygen
/**
* \file dnn/src/atlas/utils.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "src/atlas/utils.h"
#include "megcore_atlas.h"
#include "src/common/utils.h"
using namespace megdnn;
using namespace atlas;
//! [noreturn] Format "acl return <name>(<code>) occurred; expr: <msg>" and
//! raise it via megdnn_throw. Invoked by the acl_check() macro with the
//! stringified failing expression as `msg`.
void atlas::__throw_acl_error__(aclError err, const char* msg) {
auto s = ssprintf("acl return %s(%d) occurred; expr: %s",
megcore::atlas::get_error_str(err), int(err), msg);
megdnn_throw(s.c_str());
}
// vim: syntax=cpp.doxygen
/**
* \file dnn/src/atlas/utils.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#pragma once
#include "megdnn/handle.h"
#include "src/atlas/handle.h"
#include <acl/acl_base.h>
//! Evaluate an ACL call and throw (via __throw_acl_error__) when it does
//! not return ACL_ERROR_NONE; the stringified expression is included in
//! the error message.
#define acl_check(_x) \
do { \
aclError _ret = (_x); \
if (_ret != ACL_ERROR_NONE) { \
::megdnn::atlas::__throw_acl_error__(_ret, #_x); \
} \
} while (0)
namespace megdnn {
namespace atlas {

//! Downcast a generic Handle to the Atlas HandleImpl (no runtime check;
//! the caller must know the handle is an Atlas one).
inline HandleImpl* concrete_handle(Handle* handle) {
return static_cast<atlas::HandleImpl*>(handle);
}

//! Error handling functions
MEGDNN_NORETURN void __throw_acl_error__(aclError err, const char* msg);

} // namespace atlas
} // namespace megdnn
// vim: syntax=cpp.doxygen
load("//brain/megbrain/dnn:flags.bzl", "megdnn_opts")
load("@megvii3//tools/build_rules:bangc.bzl", "bangc_library")
package(default_visibility = ["//brain/megbrain/dnn:__subpackages__"])
# BANG C library containing every Cambricon MLU kernel in this directory.
# utils.cuh is pulled in as a source so device code shares the common
# helper macros with the host build.
bangc_library(
    name = "bangc_kernels",
    srcs = glob([
        "**/*.mlu",
    ]) + [
        "//brain/megbrain/dnn:src/common/utils.cuh",
    ],
    # Kernel prototypes consumed by the host-side launchers.
    hdrs = glob([
        "**/*.mlu.h",
    ]),
    deps = [
        "//brain/megbrain/dnn:public_headers",
    ],
    copts = megdnn_opts + [
        "-Ibrain/megbrain/dnn",
    ],
)
# Host-side C++ sources and headers of the Cambricon backend, exposed as a
# filegroup so the main dnn target can compile them in when enabled.
filegroup(
    name = "cambricon_backend_files",
    srcs = glob([
        "**/*.cpp",
        "**/*.h",
        "**/*.hpp",
    ]),
)
/**
* \file dnn/src/cambricon/checksum/checksum.mlu.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#pragma once
#include "src/cambricon/utils.mlu.h"
#ifdef __cplusplus
extern "C" {
#endif

//! Compute a position-weighted checksum of num_elems uint32 words of src
//! into dst[0]; launched as a UNION1 (1 cluster x 4 cores) task.
//! NOTE(review): the .mlu definitions declare src WITHOUT const — these
//! prototypes disagree with the definitions; confirm which is intended.
void checksum_kernel_union1(uint32_t* dst, const uint32_t* src, int num_elems);
//! Same checksum, launched as a UNION4 (4 clusters x 4 cores) task.
void checksum_kernel_union4(uint32_t* dst, const uint32_t* src, int num_elems);

#ifdef __cplusplus
}
#endif
// vim: ft=cpp syntax=cpp.doxygen
/**
* \file dnn/src/cambricon/checksum/checksum_kernel_union1.mlu
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#include "checksum.mlu.h"
#include "cnsccl.h"
#include "mlu.h"
#define CLUSTER_DIM 1
#define CORE_DIM 4
#define STRIDE 1024

// Checksum kernel for a UNION1 launch (1 cluster x 4 cores).
// Each task accumulates sum += src[i] * (i + 1) over a strided partition of
// the input; task 0 then reduces the per-task partial sums into dst[0].
// NOTE(review): checksum.mlu.h declares src as `const uint32_t*`; the
// definition here drops the const — confirm which signature is intended.
__mlu_entry__ void checksum_kernel_union1(uint32_t* dst, uint32_t* src,
                                          int nr_elems) {
    __nram__ uint32_t sum = 0;
    // On-chip staging buffer: input is copied GDRAM -> NRAM in chunks of
    // STRIDE words before accumulation.
    __nram__ uint32_t val[STRIDE];
    const uint32_t TASK_DIM = CLUSTER_DIM * CORE_DIM;
    // One partial-sum slot per task, reduced by task 0 below.
    __mlu_shared__ uint32_t partial_sum[TASK_DIM];
    int task_stride = STRIDE;
    int start_offset = taskId * task_stride;
    int global_stride = taskDim * task_stride;
    // Strided loop: task t handles chunks t, t + taskDim, t + 2*taskDim, ...
    for (int task_offset = start_offset; task_offset < nr_elems;
         task_offset += global_stride) {
        int end_offset = task_offset + task_stride;
        // Clamp the final (possibly partial) chunk.
        end_offset = end_offset > nr_elems ? nr_elems : end_offset;
        int copy_elems = end_offset - task_offset;
        __memcpy(val, src + task_offset, copy_elems * sizeof(uint32_t),
                 GDRAM2NRAM);
        // Position-weighted sum; uint32 arithmetic wraps by definition.
        for (int i = 0; i < copy_elems; i++) {
            sum = sum + val[i] * (task_offset + i + 1);
        }
    }
    partial_sum[taskId] = sum;
    __sync_cluster();
    // Final reduction across the cluster, done by task 0 only.
    if (taskId == 0) {
        uint32_t res = 0;
        for (int i = 0; i < taskDim; i++) {
            res += partial_sum[i];
        }
        dst[0] = res;
    }
}

#undef CLUSTER_DIM
#undef CORE_DIM
#undef STRIDE
// vim: ft=cpp syntax=cpp.doxygen
/**
* \file dnn/src/cambricon/checksum/checksum_kernel_union4.mlu
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#include "checksum.mlu.h"
#include "cnsccl.h"
#include "mlu.h"
#define CLUSTER_DIM 4
#define CORE_DIM 4
#define STRIDE 1024

// Checksum kernel for a UNION4 launch (4 clusters x 4 cores).
// Two-stage reduction: core 0 of each cluster folds its cluster's per-core
// sums into slot 0, cnscclGather collects the cluster results at root 0,
// and (cluster 0, core 0) writes the final value to dst[0].
// NOTE(review): checksum.mlu.h declares src as `const uint32_t*`; the
// definition here drops the const — confirm which signature is intended.
__mlu_entry__ void checksum_kernel_union4(uint32_t* dst, uint32_t* src,
                                          int nr_elems) {
    __nram__ uint32_t sum = 0;
    // On-chip staging buffer for GDRAM -> NRAM chunked copies.
    __nram__ uint32_t val[STRIDE];
    // Per-core partial sums within this cluster / gathered per-cluster sums.
    __mlu_shared__ uint32_t partial_sum_send[CORE_DIM];
    __mlu_shared__ uint32_t partial_sum_recv[CLUSTER_DIM];
    int task_stride = STRIDE;
    int start_offset = taskId * task_stride;
    int global_stride = taskDim * task_stride;
    // Strided partition of the input across all tasks.
    for (int task_offset = start_offset; task_offset < nr_elems;
         task_offset += global_stride) {
        int end_offset = task_offset + task_stride;
        end_offset = end_offset > nr_elems ? nr_elems : end_offset;
        int copy_elems = end_offset - task_offset;
        __memcpy(val, src + task_offset, copy_elems * sizeof(uint32_t),
                 GDRAM2NRAM);
        // Position-weighted sum; uint32 arithmetic wraps by definition.
        for (int i = 0; i < copy_elems; i++) {
            sum = sum + val[i] * (task_offset + i + 1);
        }
    }
    partial_sum_send[coreId] = sum;
    __sync_cluster();
    // Stage 1: core 0 folds this cluster's partial sums into slot 0.
    if (coreId == 0) {
        for (int i = 1; i < CORE_DIM; ++i) {
            partial_sum_send[0] += partial_sum_send[i];
        }
    }
    __sync_all();
    // Stage 2: gather every cluster's slot-0 value to root 0.
    cnscclGather((void*)&partial_sum_send, (void*)&partial_sum_recv, 1,
                 cnscclInt, 0);
    __sync_all();
    // Final reduction over the gathered per-cluster sums.
    if (clusterId == 0 && coreId == 0) {
        uint32_t res = 0;
        for (int i = 0; i < CLUSTER_DIM; ++i) {
            res += partial_sum_recv[i];
        }
        dst[0] = res;
    }
}

#undef CLUSTER_DIM
#undef CORE_DIM
#undef STRIDE
// vim: ft=cpp syntax=cpp.doxygen
/**
* \file dnn/src/cambricon/checksum/opr_impl.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#include "src/cambricon/checksum/checksum.mlu.h"
#include "src/cambricon/checksum/opr_impl.h"
#include "src/cambricon/utils.h"
#include <algorithm>
using namespace megdnn;
using namespace cambricon;
namespace {

// Host-side launcher for the checksum kernels: packs (dst, src, nr_elems)
// into a cnrt kernel-params buffer and dispatches the variant matching the
// device core version — UNION4 (16 tasks) on MLU270, UNION1 (4 tasks) on
// MLU220.
// NOTE(review): any other core version silently launches nothing, leaving
// *dst unwritten — confirm whether unsupported versions should throw.
void bang_c_wrapper(uint32_t* dst, const uint32_t* src, int nr_elems,
                    cnrtQueue_t queue, cnrtCoreVersion_t core_version) {
    cnrtKernelParamsBuffer_t params;
    cnrt_check(cnrtGetKernelParamsBuffer(&params));
    // Arguments are marshalled by value: two pointers and the element count.
    cnrt_check(cnrtKernelParamsBufferAddParam(params, &dst, sizeof(uint32_t*)));
    cnrt_check(cnrtKernelParamsBufferAddParam(params, &src, sizeof(uint32_t*)));
    cnrt_check(cnrtKernelParamsBufferAddParam(params, &nr_elems, sizeof(int)));
    if (core_version == CNRT_MLU270) {
        // 4 clusters x 4 cores = 16 tasks.
        cnrtDim3_t dim;
        dim.x = 16;
        dim.y = 1;
        dim.z = 1;
        cnrtFunctionType_t c = CNRT_FUNC_TYPE_UNION4;
        cnrt_check(cnrtInvokeKernel_V2((void*)&checksum_kernel_union4, dim,
                                       params, c, queue));
    } else if (core_version == CNRT_MLU220) {
        // 1 cluster x 4 cores = 4 tasks.
        cnrtDim3_t dim;
        dim.x = 4;
        dim.y = 1;
        dim.z = 1;
        cnrtFunctionType_t c = CNRT_FUNC_TYPE_UNION1;
        cnrt_check(cnrtInvokeKernel_V2((void*)&checksum_kernel_union1, dim,
                                       params, c, queue));
    }
    // Surface any asynchronous launch error, then release the params buffer.
    after_kernel_launch();
    cnrt_check(cnrtDestroyKernelParamsBuffer(params));
}

}  // namespace
size_t ChecksumForwardImpl::get_workspace_in_bytes(const TensorLayout& /* data */) {
    // The device kernel writes exactly one uint32 checksum into workspace,
    // regardless of the input layout.
    return sizeof(ChecksumForward::Result::checksum);
}
// Compute the checksum of a 1-D byte tensor on device.
// result.last_val receives the trailing (up to 4) raw bytes; result.checksum
// is computed by the MLU kernel over the whole-uint32 prefix of the data.
ChecksumForward::Result ChecksumForwardImpl::exec(_megdnn_tensor_in data,
                                                  _megdnn_workspace workspace) {
    Result result;
    memset(&result, 0, sizeof(result));
    check_exec(data.layout, workspace.size);
    auto queue = cnrt_queue(handle());
    auto ptr = static_cast<uint8_t*>(data.raw_ptr);
    // Only whole uint32 words are fed to the kernel.
    size_t size_all = data.layout.shape[0],
           size_ints = size_all / sizeof(uint32_t);
    // Copy the last min(size, 4) bytes into result.last_val asynchronously.
    auto last_val_size = std::min<size_t>(size_all, 4);
    cnrt_check(cnrtMemcpyAsync(&result.last_val, ptr + size_all - last_val_size,
                               last_val_size, queue,
                               CNRT_MEM_TRANS_DIR_DEV2HOST));
    if (size_ints) {
        // Launch the device kernel, writing the checksum into the workspace,
        // then read it back to the host-side result field.
        auto&& device_info = current_device_info();
        bang_c_wrapper(reinterpret_cast<uint32_t*>(workspace.raw_ptr),
                       static_cast<uint32_t*>(data.raw_ptr), size_ints, queue,
                       device_info.core_version);
        cnrt_check(cnrtMemcpyAsync(&result.checksum, workspace.raw_ptr,
                                   sizeof(result.checksum), queue,
                                   CNRT_MEM_TRANS_DIR_DEV2HOST));
    }
    // Both async copies target fields of `result`; drain the queue before
    // returning so they are safe to read.
    cnrt_check(cnrtSyncQueue(queue));
    return result;
}
// vim: syntax=cpp.doxygen
/**
* \file dnn/src/cambricon/checksum/opr_impl.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#pragma once
#include "megdnn/oprs.h"
#include "src/cambricon/utils.h"
namespace megdnn {
namespace cambricon {
// Cambricon implementation of the checksum operator. Thread-safe; the
// workspace holds the single uint32 checksum computed on device.
class ChecksumForwardImpl final : public ChecksumForward {
public:
    using ChecksumForward::ChecksumForward;
    // Workspace size: one uint32 for the device-computed checksum.
    size_t get_workspace_in_bytes(const TensorLayout&) override;
    bool is_thread_safe() const override { return true; }
    Result exec(_megdnn_tensor_in data, _megdnn_workspace workspace) override;
};
} // namespace cambricon
} // namespace megdnn
// vim: syntax=cpp.doxygen
/**
* \file dnn/src/cambricon/handle.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#include "src/common/handle_impl.h"
#include "src/common/version_symbol.h"
#include "src/cambricon/handle.h"
#include "src/cambricon/utils.h"
#include "src/cambricon/checksum/opr_impl.h"
#include <cnrt.h>
namespace megdnn {
namespace cambricon {
// Build a Cambricon handle on top of an existing megcore computing handle:
// resolve the device id, validate it against the device count, cache the
// device info and capture the megcore Cambricon context (queue).
HandleImpl::HandleImpl(megcoreComputingHandle_t comp_handle)
        : HandleImplHelper(comp_handle, HandleType::CAMBRICON) {
    // Get megcore device handle
    megcoreDeviceHandle_t dev_handle;
    megcoreGetDeviceHandle(comp_handle, &dev_handle);
    int dev_id;
    megcoreGetDeviceID(dev_handle, &dev_id);
    unsigned int dev_num;
    cnrt_check(cnrtGetDeviceCount(&dev_num));
    MEGDNN_MARK_USED_VAR(dev_num);
    // check validity of device_id
    megdnn_assert(dev_id >= 0 && static_cast<unsigned int>(dev_id) < dev_num);
    m_device_id = dev_id;
    cnrt_check(cnrtGetDeviceInfo(&m_device_info, dev_id));
    megcore::getCambriconContext(comp_handle, &m_megcore_context);
}
HandleImpl::~HandleImpl() noexcept = default;

// Fallback factory: any operator without an explicit specialization below is
// unsupported on the Cambricon backend.
template <typename Opr>
std::unique_ptr<Opr> HandleImpl::create_operator() {
    megdnn_throw("unsupported cambricon opr");
    return nullptr;
}

// Report the weakest possible alignment (1 byte); no stronger guarantee from
// the allocator is assumed here.
size_t HandleImpl::alignment_requirement() const {
    return 1;
}
MEGDNN_SPECIALIZE_CREATE_OPERATOR(ChecksumForward);
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wpragmas"
#pragma GCC diagnostic ignored "-Winstantiation-after-specialization"
MEGDNN_FOREACH_OPR_CLASS(MEGDNN_INST_CREATE_OPERATOR)
#pragma GCC diagnostic pop
} // namespace cambricon
} // namespace megdnn
MEGDNN_VERSION_SYMBOL3(CNRT, CNRT_MAJOR_VERSION, CNRT_MINOR_VERSION,
CNRT_PATCH_VERSION);
// vim: syntax=cpp.doxygen
/**
* \file dnn/src/cambricon/handle.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#pragma once
#include "megcore_cambricon.h"
#include "megdnn/basic_types.h"
#include "megdnn/handle.h"
#include "megdnn/oprs/general.h"
#include "src/common/handle_impl.h"
#include "src/common/utils.h"
#include <atomic>
#include <mutex>
#include <cnrt.h>
namespace megdnn {
namespace cambricon {
// Cambricon megdnn handle: wraps a megcore computing handle and exposes the
// cached device info, the device id and the cnrt queue of the context.
class HandleImpl : public HandleImplHelper {
public:
    HandleImpl(megcoreComputingHandle_t computing_handle);
    ~HandleImpl() noexcept;
    //! reports 1 byte (see the implementation) — no stronger guarantee
    size_t alignment_requirement() const override;
    //! cached result of cnrtGetDeviceInfo for this handle's device
    const cnrtDeviceInfo_t& device_info() const { return m_device_info; }
    //! operator factory; specialized for the oprs this backend implements
    template <typename Opr>
    std::unique_ptr<Opr> create_operator();
    const megcore::CambriconContext& megcore_context() const {
        return m_megcore_context;
    }
    int device_id() const { return m_device_id; }
    cnrtQueue_t queue() const { return megcore_context().queue; }
    //! global checksum opr, lazily created and cached by HandleImplHelper
    Checksum* checksum_opr() override final {
        return get_helper_opr<Checksum, 0>(this);
    }

private:
    int m_device_id;
    //! MegDNN handle does not manage the lifetime of cnrt queue.
    megcore::CambriconContext m_megcore_context;
    cnrtDeviceInfo_t m_device_info;
};
} // namespace cambricon
} // namespace megdnn
// vim: syntax=cpp.doxygen
/**
* \file dnn/src/cambricon/megcore/cambricon_computing_context.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#include "megcore.h"
#include "src/cambricon/utils.h"
#include "src/common/utils.h"
#include "src/cambricon/megcore/cambricon_computing_context.hpp"
using namespace megcore;
using namespace megcore::cambricon;
// Wrap an existing Cambricon queue, or create one of our own when the
// supplied context has a null queue; ownership is tracked via own_queue.
CambriconComputingContext::CambriconComputingContext(
        megcoreDeviceHandle_t dev_handle, unsigned int flags,
        const CambriconContext& ctx)
        : ComputingContext(dev_handle, flags),
          own_queue{ctx.queue == nullptr},
          context_{ctx} {
    megcorePlatform_t platform;
    megcoreGetPlatform(dev_handle, &platform);
    // This context only makes sense on a Cambricon device handle.
    megdnn_assert(platform == megcorePlatformCambricon);
    if (own_queue) {
        cnrt_check(cnrtCreateQueue(&context_.queue));
    }
}
// Destroy the queue only if this context created it; externally supplied
// queues remain owned by the caller.
CambriconComputingContext::~CambriconComputingContext() {
    if (own_queue) {
        cnrt_check(cnrtDestroyQueue(context_.queue));
    }
}
void CambriconComputingContext::memcpy(void* dst, const void* src,
size_t size_in_bytes,
megcoreMemcpyKind_t kind) {
cnrtMemTransDir_t dir;
switch (kind) {
case megcoreMemcpyDeviceToHost:
dir = CNRT_MEM_TRANS_DIR_DEV2HOST;
break;
case megcoreMemcpyHostToDevice:
dir = CNRT_MEM_TRANS_DIR_HOST2DEV;
break;
case megcoreMemcpyDeviceToDevice:
dir = CNRT_MEM_TRANS_DIR_DEV2DEV;
break;
default:
megdnn_throw(megdnn_mangle("bad cnrt mem trans dir"));
}
if (kind == megcoreMemcpyDeviceToDevice) {
cnrt_check(cnrtSyncQueue(context_.queue));
cnrt_check(cnrtMemcpy(dst, const_cast<void*>(src), size_in_bytes, dir));
return;
}
cnrt_check(cnrtMemcpyAsync(dst, const_cast<void*>(src), size_in_bytes,
context_.queue, dir));
}
// Fill device memory. cnrtMemset is synchronous, so drain the queue first to
// preserve ordering with previously enqueued work.
void CambriconComputingContext::memset(void* dst, int value,
                                       size_t size_in_bytes) {
    cnrt_check(cnrtSyncQueue(context_.queue));
    cnrt_check(cnrtMemset(dst, value, size_in_bytes));
}

// Block until every operation enqueued on this context's queue completes.
void CambriconComputingContext::synchronize() {
    cnrt_check(cnrtSyncQueue(context_.queue));
}
// vim: syntax=cpp.doxygen
/**
* \file dnn/src/cambricon/megcore/cambricon_computing_context.hpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#pragma once
#include "megcore_cambricon.h"
#include "src/common/megcore/common/computing_context.hpp"
namespace megcore {
namespace cambricon {
// megcore computing context backed by a Cambricon cnrt queue. The queue is
// either adopted from the supplied context or created (and later destroyed)
// by this class; see own_queue.
class CambriconComputingContext final : public ComputingContext {
public:
    CambriconComputingContext(megcoreDeviceHandle_t dev_handle,
                              unsigned int flags,
                              const CambriconContext& ctx = {});
    ~CambriconComputingContext();
    void memcpy(void* dst, const void* src, size_t size_in_bytes,
                megcoreMemcpyKind_t kind) override;
    void memset(void* dst, int value, size_t size_in_bytes) override;
    void synchronize() override;
    const CambriconContext& context() const { return context_; }
    cnrtQueue_t queue() const { return context().queue; }

private:
    //! true when the queue was created here and must be destroyed in dtor
    bool own_queue;
    CambriconContext context_;
};
} // namespace cambricon
} // namespace megcore
// vim: syntax=cpp.doxygen
/**
* \file dnn/src/cambricon/megcore/cambricon_device_context.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#include "megcore.h"
#include "src/cambricon/utils.h"
#include "src/common/utils.h"
#include "src/cambricon/megcore/cambricon_device_context.hpp"
#define STR_HELPER(x) #x
#define STR(x) STR_HELPER(x)
#define CNRT_VERSION_STR \
STR(CNRT_MAJOR_VERSION) \
"." STR(CNRT_MINOR_VERSION) "." STR(CNRT_PATCH_VERSION)
#pragma message "compile with cnrt " CNRT_VERSION_STR " "
#undef STR_HELPER
#undef STR
using namespace megcore;
using namespace cambricon;
// Create a device context for device_id. Performs the one-time process-wide
// cnrtInit (via init_status) unless the embedder already did it, verifies
// that the cnrt runtime version matches the headers this was compiled
// against, and validates the device id before caching its info.
CambriconDeviceContext::CambriconDeviceContext(int device_id,
                                               unsigned int flags,
                                               bool global_initialized)
        : DeviceContext(megcorePlatformCambricon, device_id, flags) {
    if (!global_initialized)
        init_status.init();
    unsigned int version;
    cnrt_check(cnrtGetVersion(&version));
    megdnn_assert(version == CNRT_VERSION,
                  "megcore compiled with cnrt %d, get %d at runtime",
                  CNRT_VERSION, version);
    unsigned int dev_num;
    cnrt_check(cnrtGetDeviceCount(&dev_num));
    MEGDNN_MARK_USED_VAR(dev_num);
    // check validity of device_id
    megdnn_assert(device_id >= 0 &&
                  static_cast<unsigned int>(device_id) < dev_num);
    cnrt_check(cnrtGetDeviceInfo(&device_info, device_id));
}
CambriconDeviceContext::~CambriconDeviceContext() noexcept = default;

// Report the weakest possible alignment (1 byte) for device allocations.
size_t CambriconDeviceContext::mem_alignment_in_bytes() const noexcept {
    return 1;
}
void CambriconDeviceContext::activate() {
int id = device_id();
cnrtDev_t dev;
cnrt_check(cnrtGetDeviceHandle(&dev, id));
cnrt_check(cnrtSetCurrentDevice(dev));
}
void* CambriconDeviceContext::malloc(size_t size_in_bytes) {
void* ptr;
cnrt_check(cnrtMalloc(&ptr, size_in_bytes));
return ptr;
}
void CambriconDeviceContext::free(void* ptr) {
cnrt_check(cnrtFree(ptr));
}
CambriconDeviceContext::InitStatus CambriconDeviceContext::init_status;
// vim: syntax=cpp.doxygen
/**
* \file dnn/src/cambricon/megcore/cambricon_device_context.hpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#pragma once
#include <mutex>
#include "megcore_cambricon.h"
#include "src/common/megcore/common/device_context.hpp"
#include "src/common/utils.h"
namespace megcore {
namespace cambricon {
// Cambricon device context: device activation and raw device memory
// management. Also owns the process-wide cnrt init/destroy logic through
// the static InitStatus member.
class CambriconDeviceContext : public DeviceContext {
public:
    CambriconDeviceContext(int device_id, unsigned int flags,
                           bool global_initialized = false);
    ~CambriconDeviceContext() noexcept;
    size_t mem_alignment_in_bytes() const noexcept override;
    void activate() override;
    void* malloc(size_t size_in_bytes) override;
    void free(void* ptr) override;
    //! Guards the one-time cnrtInit()/cnrtDestroy() pair for the process.
    struct InitStatus {
        bool initialized;
        std::mutex mtx;
        InitStatus() : initialized{false} {}
        // Thread-safe, idempotent cnrt initialization; asserts on failure.
        void init() {
            std::lock_guard<std::mutex> guard{mtx};
            if (!initialized) {
                auto cnrt_err = cnrtInit(0);
                initialized = cnrt_err == CNRT_RET_SUCCESS;
                megdnn_assert(initialized, "cnrt initialize failed: (cnrt:%d)",
                              static_cast<int>(cnrt_err));
            }
        }
        // NOTE(review): reads `initialized` without the mutex — presumably
        // runs single-threaded at static destruction time; confirm.
        ~InitStatus() {
            if (initialized) {
                cnrtDestroy();
                initialized = false;
            }
        }
    };
    static InitStatus init_status;

private:
    cnrtDeviceInfo_t device_info;
};
} // namespace cambricon
} // namespace megcore
// vim: syntax=cpp.doxygen
/**
* \file dnn/src/cambricon/megcore/public_api/computing.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#include "megcore_cambricon.h"
#include "src/cambricon/megcore/cambricon_computing_context.hpp"
#include "src/cambricon/megcore/cambricon_device_context.hpp"
#include "src/common/megcore/public_api/computing.hpp"
#include "src/common/megcore/public_api/device.hpp"
#include "src/common/utils.h"
using namespace megcore;
// Create a megcore Cambricon device handle. global_initialized tells the
// device context to skip the process-wide cnrtInit (caller already did it).
megcoreStatus_t megcore::createDeviceHandleWithGlobalInitStatus(
        megcoreDeviceHandle_t* devHandle, int deviceID, unsigned int flags,
        bool global_initialized) {
    auto content = megdnn::make_unique<cambricon::CambriconDeviceContext>(
            deviceID, flags, global_initialized);
    auto& ctx = *devHandle;
    ctx = new megcoreDeviceContext;
    ctx->content = std::move(content);
    return megcoreSuccess;
}
// Create a computing handle whose context adopts ctx.queue (or creates a
// fresh queue when ctx.queue is null — see CambriconComputingContext).
megcoreStatus_t megcore::createComputingHandleWithCambriconContext(
        megcoreComputingHandle_t* compHandle, megcoreDeviceHandle_t devHandle,
        unsigned int flags, const CambriconContext& ctx) {
    auto content = megdnn::make_unique<cambricon::CambriconComputingContext>(
            devHandle, flags, ctx);
    auto& H = *compHandle;
    H = new megcoreComputingContext;
    H->content = std::move(content);
    return megcoreSuccess;
}
// Retrieve the CambriconContext stored inside a computing handle; asserts
// the handle belongs to the Cambricon platform before downcasting.
megcoreStatus_t megcore::getCambriconContext(megcoreComputingHandle_t handle,
                                             CambriconContext* ctx) {
    auto&& H = handle;
    megdnn_assert(H);
    megcoreDeviceHandle_t dev_handle = H->content->dev_handle();
    megcorePlatform_t platform;
    megcoreGetPlatform(dev_handle, &platform);
    megdnn_assert(platform == megcorePlatformCambricon);
    auto context = static_cast<megcore::cambricon::CambriconComputingContext*>(
            H->content.get());
    *ctx = context->context();
    return megcoreSuccess;
}
// vim: syntax=cpp.doxygen
/**
* \file dnn/src/cambricon/utils.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#include "src/cambricon/utils.h"
#include "src/cambricon/utils.mlu.h"
#include "src/cambricon/handle.h"
#include "src/common/utils.h"
#include <mutex>
#include <unordered_map>
using namespace megdnn;
using namespace cambricon;
namespace {

// Per-device cache of cnrtGetDeviceInfo, filled lazily under its own mutex.
struct DeviceInfoRecord {
    bool init = false;
    cnrtDeviceInfo_t device_info;
    std::mutex mtx;
};

// Lazily built map from cnrt device handle to its ordinal device id,
// protected by dev2device_id_mtx.
std::unordered_map<cnrtDev_t, int> dev2device_id;
std::mutex dev2device_id_mtx;
// Fixed upper bound on the number of devices the cache can describe.
constexpr int MAX_NR_DEVICE = 64;
DeviceInfoRecord device_info_rec[MAX_NR_DEVICE];

}  // namespace
void cambricon::__throw_cnrt_error__(cnrtRet_t err, const char* msg) {
    // Build a diagnostic carrying the symbolic error string, the numeric
    // code and the failing expression, then raise it as a megdnn exception.
    std::string desc = ssprintf("cnrt return %s(%d) occurred; expr: %s",
                                cnrtGetErrorStr(err), int(err), msg);
    megdnn_throw(desc.c_str());
}
// Return the (cached) cnrtDeviceInfo_t of the currently active device.
// The first call enumerates all devices to build the handle -> id map;
// per-device info is then fetched once and cached in device_info_rec.
cnrtDeviceInfo_t cambricon::current_device_info() {
    static bool dev2device_id_init = false;
    {
        // Build the handle->id map exactly once, guarded by the map mutex.
        std::lock_guard<std::mutex> lock(dev2device_id_mtx);
        if (!dev2device_id_init) {
            unsigned int dev_num = 0;
            cnrt_check(cnrtGetDeviceCount(&dev_num));
            for (unsigned int dev_id = 0; dev_id < dev_num; ++dev_id) {
                cnrtDev_t dev;
                cnrt_check(cnrtGetDeviceHandle(&dev, dev_id));
                dev2device_id[dev] = dev_id;
            }
            dev2device_id_init = true;
        }
    }
    cnrtDev_t dev;
    cnrt_check(cnrtGetCurrentDevice(&dev));
    {
        std::lock_guard<std::mutex> lock(dev2device_id_mtx);
        // .at() throws if the current device was never enumerated above.
        int dev_id = dev2device_id.at(dev);
        auto& rec = device_info_rec[dev_id];
        {
            // Each record has its own mutex so devices initialize
            // independently of each other.
            std::lock_guard<std::mutex> lock(rec.mtx);
            if (!rec.init) {
                cnrt_check(cnrtGetDeviceInfo(&rec.device_info, dev_id));
                rec.init = true;
            }
        }
        return rec.device_info;
    }
}
// vim: syntax=cpp.doxygen
/**
* \file dnn/src/cambricon/utils.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#pragma once
#include "megcore_cdefs.h"
#include "megdnn/handle.h"
#include "src/cambricon/utils.mlu.h"
#include "src/common/utils.h"
#include "src/cambricon/handle.h"
#include <cnrt.h>
namespace megdnn {
namespace cambricon {
//! Downcast a generic megdnn Handle to the Cambricon HandleImpl.
//! NOTE(review): unchecked static_cast — callers must guarantee the handle
//! really belongs to the Cambricon platform.
static inline HandleImpl* concrete_handle(Handle* handle) {
    return static_cast<cambricon::HandleImpl*>(handle);
}

//! Convenience accessor for the cnrt queue held by a handle's context.
static inline cnrtQueue_t cnrt_queue(Handle* handle) {
    return concrete_handle(handle)->queue();
}

//! get device info of current active device
cnrtDeviceInfo_t current_device_info();
} // namespace cambricon
} // namespace megdnn
// vim: syntax=cpp.doxygen
/**
* \file dnn/src/cambricon/utils.mlu.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#pragma once
#include "src/common/utils.cuh"
#include <stdint.h>
#include <cnrt.h>
//! Evaluate a cnrt expression; any status other than CNRT_RET_SUCCESS is
//! converted into a megdnn exception carrying the code and the stringified
//! failing expression.
#define cnrt_check(_x)                                         \
    do {                                                       \
        cnrtRet_t _ret = (_x);                                 \
        if (_ret != CNRT_RET_SUCCESS) {                        \
            ::megdnn::cambricon::__throw_cnrt_error__(_ret, #_x); \
        }                                                      \
    } while (0)

//! Surface any asynchronous error recorded by a preceding kernel launch.
#define after_kernel_launch()           \
    do {                                \
        cnrt_check(cnrtGetLastErr());   \
    } while (0)

namespace megdnn {
namespace cambricon {

//! Error handling functions
MEGDNN_NORETURN void __throw_cnrt_error__(cnrtRet_t err, const char* msg);

}  // namespace cambricon
}  // namespace megdnn
// vim: syntax=cpp.doxygen
......@@ -36,6 +36,14 @@
#endif
#if MEGDNN_WITH_CAMBRICON
#include "src/cambricon/handle.h"
#endif
#ifdef MEGDNN_WITH_ATLAS
#include "src/atlas/handle.h"
#endif
using namespace megdnn;
MIDOUT_DECL(HandlePlatform);
......@@ -83,6 +91,20 @@ std::unique_ptr<Handle> Handle::make(megcoreComputingHandle_t computing_handle,
}
}
MIDOUT_END();
#endif
}
else if (platform == megcorePlatformCambricon) {
#if MEGDNN_WITH_CAMBRICON
return make_unique<cambricon::HandleImpl>(computing_handle);
#else
return nullptr;
#endif
}
else if (platform == megcorePlatformAtlas) {
#if MEGDNN_WITH_ATLAS
return make_unique<atlas::HandleImpl>(computing_handle);
#else
return nullptr;
#endif
}
else {
......@@ -94,6 +116,7 @@ std::unique_ptr<Handle> Handle::make(megcoreComputingHandle_t computing_handle,
return nullptr;
#endif
}
return nullptr;
}
......@@ -166,6 +189,12 @@ std::unique_ptr<Handle> Handle::make(megcoreComputingHandle_t computing_handle,
#endif // !MEGDNN_NAIVE
#if MEGDNN_WITH_CUDA
CASE(CUDA,cuda);
#endif
#if MEGDNN_WITH_ATLAS
CASE(ATLAS, atlas);
#endif
#if MEGDNN_WITH_CAMBRICON
CASE(CAMBRICON, cambricon);
#endif
default:
megdnn_throw(megdnn_mangle("bad handle type"));
......
......@@ -18,6 +18,14 @@
#endif
#if MEGDNN_WITH_CAMBRICON
#include "src/cambricon/megcore/cambricon_computing_context.hpp"
#endif
#if MEGDNN_WITH_ATLAS
#include "src/atlas/megcore/computing_context.hpp"
#endif
using namespace megcore;
using namespace megdnn;
......@@ -32,6 +40,15 @@ std::unique_ptr<ComputingContext> ComputingContext::make(
#if MEGDNN_WITH_CUDA
case megcorePlatformCUDA:
return make_unique<cuda::CUDAComputingContext>(dev_handle, flags);
#endif
#if MEGDNN_WITH_CAMBRICON
case megcorePlatformCambricon:
return make_unique<cambricon::CambriconComputingContext>(dev_handle,
flags);
#endif
#if MEGDNN_WITH_ATLAS
case megcorePlatformAtlas:
return make_unique<atlas::AtlasComputingContext>(dev_handle, flags);
#endif
default:
megdnn_throw("bad platform");
......
......@@ -15,6 +15,13 @@
#if MEGDNN_WITH_CUDA
#include "src/cuda/megcore/cuda_device_context.hpp"
#endif
#if MEGDNN_WITH_CAMBRICON
#include "src/cambricon/megcore/cambricon_device_context.hpp"
#endif
#if MEGDNN_WITH_ATLAS
#include "src/atlas/megcore/device_context.hpp"
#endif
using namespace megcore;
using namespace megdnn;
......@@ -28,6 +35,15 @@ std::unique_ptr<DeviceContext> DeviceContext::make(megcorePlatform_t platform,
#if MEGDNN_WITH_CUDA
case megcorePlatformCUDA:
return make_unique<cuda::CUDADeviceContext>(deviceID, flags);
#endif
#if MEGDNN_WITH_CAMBRICON
case megcorePlatformCambricon:
return make_unique<cambricon::CambriconDeviceContext>(deviceID,
flags);
#endif
#if MEGDNN_WITH_ATLAS
case megcorePlatformAtlas:
return make_unique<atlas::AtlasDeviceContext>(deviceID, flags);
#endif
default:
megdnn_throw("bad platform");
......
......@@ -26,6 +26,16 @@ if(MGE_WITH_CUDA)
endif()
if(MGE_WITH_CAMBRICON)
file(GLOB_RECURSE SOURCES_ cambricon/*.cpp)
list(APPEND SOURCES ${SOURCES_})
endif()
if(MGE_WITH_ATLAS)
file(GLOB_RECURSE SOURCES_ atlas/*.cpp)
list(APPEND SOURCES ${SOURCES_})
endif()
add_executable(megdnn_test ${SOURCES})
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-narrowing")
......
/**
* \file dnn/test/atlas/checksum.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "megdnn/oprs.h"
#include "test/atlas/fixture.h"
#include "test/common/checker.h"
using namespace megdnn;
using namespace test;
// Compare the Atlas checksum implementation against the naive (CPU)
// reference over several sizes and buffer start offsets. Fixes copy-paste
// leftovers from the cambricon variant of this test: locals were still
// named res_cambricon, the run_offset lambda was misspelled, and an unused
// RNG was constructed.
TEST_F(ATLAS, CHECKSUM_FORWARD) {
    auto atlas_opr = handle_atlas()->create_operator<megdnn::Checksum>(),
         naive_opr = handle_naive()->create_operator<megdnn::Checksum>();
    for (size_t size :
         {3, 8, 4 * 4 * 1024, 12345, 1024 * 1024, 1024 * 1024 * 10}) {
        // Round the host buffer up to a 512-byte multiple so the offset runs
        // below never read past the allocation.
        auto aligned_size = size + ((512 - size % 512) % 512);
        // Execute one operator on the given buffer; optionally log the
        // workspace size the first time.
        auto run = [&](megdnn::Checksum* opr, void* ptr, bool log_size) {
            TensorND tensor;
            tensor.raw_ptr = ptr;
            tensor.layout.init_contiguous_stride({size});
            tensor.layout.dtype = dtype::Byte();
            WorkspaceWrapper workspace(
                    handle_atlas(), opr->get_workspace_in_bytes(tensor.layout));
            if (log_size) {
                printf("checksum(%zu): workspace=%zu\n", size,
                       workspace.workspace().size);
            }
            return opr->exec(tensor, workspace.workspace());
        };
        std::vector<uint8_t> buf(aligned_size);
        for (size_t i = 0; i < size; ++i)
            buf[i] = 1;
        auto run_offset = [&](size_t offset) {
            void* dev_ptr = megdnn_malloc(handle_atlas(), buf.size() + offset);
            void* dev_buf = static_cast<char*>(dev_ptr) + offset;
            // Run twice, flipping the last byte the second time: each result
            // must match naive, and the two inputs must yield different
            // checksums.
            Checksum::Result res_atlas[2], res_naive[2];
            for (int change_last = 0; change_last < 2; ++change_last) {
                if (change_last)
                    ++buf[size - 1];
                megdnn_memcpy_H2D(handle_atlas(), dev_buf, buf.data(), size);
                res_atlas[change_last] =
                        run(atlas_opr.get(), dev_buf, !change_last);
                res_naive[change_last] =
                        run(naive_opr.get(), buf.data(), false);
            }
            megdnn_free(handle_atlas(), dev_ptr);
            ASSERT_EQ(res_naive[0], res_atlas[0])
                    << "failed for size " << size;
            ASSERT_EQ(res_naive[1], res_atlas[1]);
            ASSERT_NE(res_atlas[0], res_atlas[1]);
        };
        for (size_t i = 0; i < 8; ++i) {
            run_offset(i);
        }
    }
}
// vim: syntax=cpp.doxygen
/**
* \file dnn/test/atlas/fixture.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "test/atlas/fixture.h"
#include "src/atlas/handle.h"
#include "src/atlas/megcore/device_context.hpp"
#include "src/atlas/utils.h"
#include "test/common/memory_manager.h"
#include "test/common/random_state.h"
#include "test/common/utils.h"
#include "acl/acl.h"
#include <cstdlib>
using namespace megdnn;
using namespace test;
// Create an Atlas device handle on card 0, activate it, and build the
// megdnn handle used by the tests on top of a fresh computing handle.
void ATLAS::SetUp() {
    RandomState::reset();
    // use card 0
    megcore_check(
            megcoreCreateDeviceHandle(&m_dev_handle, megcorePlatformAtlas, 0));
    megcoreActivate(m_dev_handle);
    megcoreComputingHandle_t comp_handle;
    megcore_check(megcoreCreateComputingHandle(&comp_handle, m_dev_handle));
    m_handle_atlas = Handle::make(comp_handle);
    megdnn_assert(m_handle_atlas);
}
// Lazily create the naive CPU reference handle on first use.
Handle* ATLAS::handle_naive() {
    if (!m_handle_naive)
        m_handle_naive = create_cpu_handle(2);
    return m_handle_naive.get();
}

// Release both handles and cached test memory, then deactivate the device.
void ATLAS::TearDown() {
    m_handle_naive.reset();
    m_handle_atlas.reset();
    MemoryManagerHolder::instance()->clear();
    megcoreDeactivate(m_dev_handle);
}
// vim: syntax=cpp.doxygen
/**
* \file dnn/test/atlas/fixture.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#pragma once
#include <gtest/gtest.h>
#include "megcore_cdefs.h"
#include "megdnn/handle.h"
#include <memory>
namespace megdnn {
namespace test {
//! GTest fixture providing a handle on an Atlas device together with a
//! naive CPU handle for computing reference results.
class ATLAS : public ::testing::Test {
public:
    void SetUp() override;
    void TearDown() override;
    //! handle on the Atlas device; valid between SetUp() and TearDown()
    Handle* handle_atlas() { return m_handle_atlas.get(); }
    //! lazily-created naive CPU handle for reference results
    Handle* handle_naive();
private:
    std::unique_ptr<Handle> m_handle_naive;
    std::unique_ptr<Handle> m_handle_atlas;
    megcoreDeviceHandle_t m_dev_handle;
};
} // namespace test
} // namespace megdnn
// vim: syntax=cpp.doxygen
/**
* \file dnn/test/cambricon/checksum.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#include "megdnn/oprs.h"
#include "test/cambricon/fixture.h"
#include "test/common/checker.h"
using namespace megdnn;
using namespace test;
// Checksum correctness test: run the Cambricon Checksum opr and the naive
// CPU implementation on identical data and require identical results, for
// several sizes and for start offsets 0..7 (alignment coverage).  Also
// checks that flipping the last byte changes the checksum.
TEST_F(CAMBRICON, CHECKSUM_FORWARD) {
    auto cambricon_opr =
                 handle_cambricon()->create_operator<megdnn::Checksum>(),
         naive_opr = handle_naive()->create_operator<megdnn::Checksum>();
    // NOTE(review): rng is never used below -- candidate for removal
    std::mt19937 rng(std::random_device{}());
    for (size_t size :
         {3, 8, 4 * 4 * 1024, 12345, 1024 * 1024, 1024 * 1024 * 10}) {
        // host/device buffers are padded up to a multiple of 512 bytes
        auto aligned_size = size + ((512 - size % 512) % 512);
        // run one checksum over `size` bytes at `ptr` with the given opr
        auto run = [&](megdnn::Checksum* opr, void* ptr, bool log_size) {
            TensorND tensor;
            tensor.raw_ptr = ptr;
            tensor.layout.init_contiguous_stride({size});
            tensor.layout.dtype = dtype::Byte();
            WorkspaceWrapper workspace(
                    handle_cambricon(),
                    opr->get_workspace_in_bytes(tensor.layout));
            if (log_size) {
                printf("checksum(%zu): workspace=%zu\n", size,
                       workspace.workspace().size);
            }
            return opr->exec(tensor, workspace.workspace());
        };
        std::vector<uint8_t> buf(aligned_size);
        for (size_t i = 0; i < size; ++i)
            buf[i] = 1;
        // run both oprs on a device buffer starting at `offset` bytes, once
        // with the original data and once with the last byte incremented
        // (NOTE(review): "run_offsset" is a typo for run_offset, local only)
        auto run_offsset = [&](size_t offset) {
            void* dev_ptr =
                    megdnn_malloc(handle_cambricon(), buf.size() + offset);
            void* dev_buf = static_cast<char*>(dev_ptr) + offset;
            Checksum::Result res_cambricon[2], res_naive[2];
            for (int change_last = 0; change_last < 2; ++change_last) {
                if (change_last)
                    ++buf[size - 1];
                megdnn_memcpy_H2D(handle_cambricon(), dev_buf, buf.data(),
                                  size);
                res_cambricon[change_last] =
                        run(cambricon_opr.get(), dev_buf, !change_last);
                res_naive[change_last] =
                        run(naive_opr.get(), buf.data(), false);
            }
            megdnn_free(handle_cambricon(), dev_ptr);
            // both backends must agree, and the checksum must be sensitive
            // to a change in the last byte
            ASSERT_EQ(res_naive[0], res_cambricon[0])
                    << "failed for size " << size;
            ASSERT_EQ(res_naive[1], res_cambricon[1]);
            ASSERT_NE(res_cambricon[0], res_cambricon[1]);
        };
        for (size_t i = 0; i < 8; ++i) {
            run_offsset(i);
        }
    }
}
// vim: syntax=cpp.doxygen
/**
* \file dnn/test/cambricon/fixture.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#include "test/cambricon/fixture.h"
#include "src/cambricon/handle.h"
#include "src/cambricon/utils.h"
#include "test/common/memory_manager.h"
#include "test/common/random_state.h"
#include "test/common/utils.h"
#include <cnrt.h>
#include <cstdlib>
using namespace megdnn;
using namespace test;
// Per-test setup: create Cambricon device/computing handles on card 0 and
// wrap the computing handle in a megdnn Handle.  RNG state is reset first
// for deterministic test data.
void CAMBRICON::SetUp() {
    RandomState::reset();
    // NOTE(review): dev_handle is local and neither stored nor destroyed
    // here (the ATLAS fixture keeps and deactivates its device handle) --
    // confirm whether the device handle is owned by comp_handle.
    megcoreDeviceHandle_t dev_handle;
    // use card 0
    megcore_check(megcoreCreateDeviceHandle(&dev_handle,
                                            megcorePlatformCambricon, 0));
    megcoreComputingHandle_t comp_handle;
    megcore_check(megcoreCreateComputingHandle(&comp_handle, dev_handle));
    m_handle_cambricon = Handle::make(comp_handle);
    megdnn_assert(m_handle_cambricon);
}
//! Return the naive CPU handle used for reference results, creating it
//! lazily on first access.
Handle* CAMBRICON::handle_naive() {
    if (m_handle_naive == nullptr) {
        m_handle_naive = create_cpu_handle(2);
    }
    return m_handle_naive.get();
}
// Per-test teardown: release handles, then drop cached memory managers
// that may still reference them.
void CAMBRICON::TearDown() {
    m_handle_naive.reset();
    m_handle_cambricon.reset();
    MemoryManagerHolder::instance()->clear();
}
// vim: syntax=cpp.doxygen
/**
* \file dnn/test/cambricon/fixture.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#pragma once
#include <gtest/gtest.h>
#include "test/common/fix_gtest_on_platforms_without_exception.inl"
#include "megcore_cdefs.h"
#include "megdnn/handle.h"
#include <memory>
namespace megdnn {
namespace test {
//! GTest fixture providing a handle on a Cambricon device together with a
//! naive CPU handle for computing reference results.
class CAMBRICON : public ::testing::Test {
public:
    void SetUp() override;
    void TearDown() override;
    //! handle on the Cambricon device; valid between SetUp() and TearDown()
    Handle* handle_cambricon() { return m_handle_cambricon.get(); }
    //! lazily-created naive CPU handle for reference results
    Handle* handle_naive();
private:
    std::unique_ptr<Handle> m_handle_naive;
    std::unique_ptr<Handle> m_handle_cambricon;
};
} // namespace test
} // namespace megdnn
// vim: syntax=cpp.doxygen
# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# pylint: disable=too-many-lines
from typing import List
import megengine._internal as mgb
from ..core import Tensor, wrap_io_tensor
@wrap_io_tensor
def cambricon_subgraph(
    inputs: List[Tensor], data: bytes, symbol: str, tensor_dim_mutable: bool,
) -> List[Tensor]:
    """Load a serialized Cambricon subgraph (i.e. cnrtModel_t) and
    execute the operations defined in the subgraph.

    :param inputs: list of input tensors of the subgraph.
    :param data: the serialized subgraph.
    :param symbol: the name of the function in the subgraph.
        The function corresponds to a cnmlFusionOp
        which is added to the cnmlModel_t/cnrtModel_t.
    :param tensor_dim_mutable: whether the input tensors' shapes are mutable
        in cnrtModel_t.
    :return: list of output tensors produced by the subgraph.
    """
    # unwrap Tensor -> underlying symvar before handing off to the C++ opr
    return mgb.opr.cambricon_runtime(
        data, symbol, tuple(map(lambda x: x._symvar, inputs)), tensor_dim_mutable
    )
@wrap_io_tensor
def extern_opr_subgraph(
    inputs, output_shapes: List[tuple], dump_name: str, dump_data: bytes,
) -> List[Tensor]:
    """Load a serialized extern opr subgraph and fake execute the operator.

    :param inputs: a single tensor or a list of input tensors.
    :param output_shapes: the output shapes.
    :param dump_name: the serialized subgraph name.
    :param dump_data: the serialized subgraph.
    :return: list of tensors.
    """
    # accept a bare tensor for convenience; normalize to a list
    tensor_list = inputs if isinstance(inputs, list) else [inputs]
    return mgb.opr.extern_c_opr_placeholder(
        tensor_list, output_shapes, dump_name=dump_name, dump_data=dump_data,
    )
# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import numpy as np
from ..functional.external import cambricon_subgraph, extern_opr_subgraph
from .module import Module
class CambriconSubgraph(Module):
    r"""Module wrapper around a serialized Cambricon subgraph.

    See :func:`~.cambricon_subgraph` for more details.
    """

    def __init__(
        self, data, symbol, tensor_dim_mutable,
    ):
        super(CambriconSubgraph, self).__init__()
        self._data = data
        self.symbol = symbol
        self.tensor_dim_mutable = tensor_dim_mutable

    @property
    def data(self):
        # expose the serialized model as raw bytes
        return self._data.tobytes()

    @data.setter
    def data(self, val):
        # keep an immutable uint8 view over the supplied bytes
        self._data = np.frombuffer(val, dtype=np.uint8)

    def forward(self, inputs):
        return cambricon_subgraph(
            inputs, self._data, self.symbol, self.tensor_dim_mutable,
        )
class ExternOprSubgraph(Module):
    r"""Module wrapper around a serialized extern opr subgraph."""

    def __init__(self, data, name, output_shapes):
        super(ExternOprSubgraph, self).__init__()
        self.data = data
        self.name = name
        self.output_shapes = output_shapes

    def forward(self, inputs):
        return extern_opr_subgraph(inputs, self.output_shapes, self.name, self.data)
......@@ -246,6 +246,27 @@ SymbolVarArray _Opr::tensor_rt_runtime(const SymbolVarArray& inputs,
}
#endif
#if MGB_ATLAS
#include "megbrain/opr/atlas_runtime_op.h"
//! build an AtlasRuntimeOpr from a python bytes object holding the
//! serialized model; asserts the bytes are non-empty
SymbolVarArray _Opr::atlas_runtime(const SymbolVarArray& inputs,
                                   PyObject* data_bytes,
                                   const OperatorNodeConfig& config) {
    mgb_assert(PyBytes_Check(data_bytes));
    auto size = PyBytes_Size(data_bytes);
    mgb_assert(size, "atlas data bytes should not be empty");
    return opr::AtlasRuntimeOpr::make(PyBytes_AsString(data_bytes), size,
                                      inputs, config);
}
#else
//! stub used when Atlas support is compiled out: always throws
SymbolVarArray _Opr::atlas_runtime(const SymbolVarArray& inputs,
                                   PyObject* data_bytes,
                                   const OperatorNodeConfig& config) {
    mgb_throw(MegBrainError, "Atlas disabled at compile time");
}
#endif
SymbolVar _Opr::timestamp(SymbolVar input, PyObject* dest, size_t dest_off,
const OperatorNodeConfig& config) {
......@@ -266,4 +287,27 @@ SymbolVar _Opr::virtual_dep(const SymbolVarArray& symvars,
}
#if MGB_CAMBRICON
#include "megbrain/cambricon/cambricon_runtime_opr.h"
//! build a CambriconRuntimeOpr from a python bytes object holding the
//! serialized cnrt model; asserts the bytes are non-empty
SymbolVarArray _Opr::cambricon_runtime(PyObject* data_bytes, const char* symbol,
                                       const SymbolVarArray& inputs,
                                       bool tensor_dim_mutable,
                                       const OperatorNodeConfig& config) {
    mgb_assert(PyBytes_Check(data_bytes));
    auto size = PyBytes_Size(data_bytes);
    mgb_assert(size, "cambricon data bytes should not be empty");
    return opr::CambriconRuntimeOpr::make(PyBytes_AsString(data_bytes), size,
                                          symbol, inputs, tensor_dim_mutable,
                                          config);
}
#else
//! stub used when Cambricon support is compiled out: always throws
SymbolVarArray _Opr::cambricon_runtime(PyObject* data_bytes, const char* symbol,
                                       const SymbolVarArray& inputs,
                                       bool tensor_dim_mutable,
                                       const OperatorNodeConfig& config) {
    mgb_throw(MegBrainError, "Cambricon disabled at compile time");
}
#endif
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
......@@ -130,6 +130,16 @@ static SymbolVar virtual_loss(const SymbolVarArray& ys,
static SymbolVar virtual_dep(const SymbolVarArray& symvars,
const OperatorNodeConfig& config);
//! create an Atlas runtime opr from serialized model bytes; throws when
//! Atlas support is compiled out
static SymbolVarArray atlas_runtime(const SymbolVarArray& inputs,
                                    PyObject* data_bytes,
                                    const OperatorNodeConfig& config);
//! create a Cambricon runtime opr from a serialized cnrt model; throws
//! when Cambricon support is compiled out
static SymbolVarArray cambricon_runtime(PyObject* data_bytes,
                                        const char* symbol,
                                        const SymbolVarArray& inputs,
                                        bool tensor_dim_mutable,
                                        const OperatorNodeConfig& config);
#ifdef SWIG
%pythoncode {
......
# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import os
import numpy as np
import megengine as mge
from megengine import tensor
from megengine.module import Module
from megengine.module.external import CambriconSubgraph
class MyModule(Module):
    """Minimal module that forwards its input through a Cambricon subgraph
    named "subnet0" with mutable tensor dims enabled."""

    def __init__(self, data):
        super().__init__()
        self.cambricon = CambriconSubgraph(data, "subnet0", True)

    def forward(self, inputs):
        return self.cambricon(inputs)
def test_cambricon_module():
    # Smoke test: load a prebuilt offline model shipped next to this file
    # and run one fp16 inference on the first cambricon device.
    model = "CambriconRuntimeOprTest.MutableBatchSize.mlu"
    model = os.path.join(os.path.dirname(__file__), model)
    with open(model, "rb") as f:
        data = f.read()
    m = MyModule(data)
    inputs = []
    # fp16 input on cambricon card 0; shape (1, 64, 32, 32) -- presumably
    # matches the model's input layout, confirm against the .mlu file
    inputs.append(tensor(dtype=np.float16, device="cambricon0"))
    inputs[0].set_value(np.random.normal(size=(1, 64, 32, 32)).astype(np.float16))

    def inference(inps):
        pred = m(inps)
        return pred

    # only checks that inference completes without error
    pred = inference(inputs)
......@@ -33,6 +33,14 @@ if(MGE_WITH_CUDA AND MGE_WITH_TRT)
list(APPEND SOURCES ${SOURCES_})
endif()
# Cambricon support: expose the public headers and compile every cambricon
# implementation source into megbrain.
if(MGE_WITH_CAMBRICON)
    list(APPEND MGB_INC ${CMAKE_CURRENT_LIST_DIR}/cambricon/include)
    # NOTE(review): GLOB_RECURSE will not pick up newly added files until the
    # next re-configure; kept to match the other backends in this file.
    file(GLOB_RECURSE SOURCES_ cambricon/impl/*.cpp cambricon/impl/*.inl)
    list(APPEND SOURCES ${SOURCES_})
endif()
# Mirror the user-facing options into MGB_* variables -- presumably consumed
# by a configured header / genfiles (TODO confirm where they are used).
set(MGB_CAMBRICON ${MGE_WITH_CAMBRICON})
set(MGB_ATLAS ${MGE_WITH_ATLAS})
if(MGE_WITH_CUDA)
file(GLOB_RECURSE SOURCES_ opr/impl/standalone/*.cu)
......@@ -77,6 +85,8 @@ if(MGE_WITH_DISTRIBUTED)
target_link_libraries (megbrain PRIVATE megray)
endif()
target_link_libraries(megbrain PRIVATE ${MGE_CUDA_LIBS})
target_link_libraries(megbrain PUBLIC ${MGE_CAMBRICON_LIBS})
target_link_libraries(megbrain PUBLIC ${MGE_ATLAS_LIBS})
if(MGE_WITH_JIT AND MGE_WITH_HALIDE)
target_link_libraries(megbrain PRIVATE libhalide)
target_link_libraries(megbrain PRIVATE ${HALIDE_LLVM_LIBS})
......
/**
* \file src/cambricon/impl/cambricon_runtime_opr.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#include "megbrain/cambricon/cambricon_runtime_opr.h"
#include "megbrain/common.h"
#if MGB_CAMBRICON
using namespace mgb;
using namespace opr;
namespace {
//! convert a megbrain TensorShape into the int vector form used by cnrt
//! param descriptors
SmallVector<int> mgb_shape_to_cnrt_shape(TensorShape mgb_shp) {
    const int nr_dims = mgb_shp.ndim;
    SmallVector<int> result(nr_dims);
    for (int axis = 0; axis < nr_dims; ++axis) {
        result[axis] = mgb_shp[axis];
    }
    return result;
}
//! build a megbrain TensorShape from a cnrt dimension array
TensorShape cnrt_shape_to_mgb_shape(int* dim_values, int dim_num) {
    TensorShape result;
    result.ndim = dim_num;
    for (int axis = 0; axis < dim_num; ++axis) {
        result[axis] = dim_values[axis];
    }
    return result;
}
//! map a cnrt data type onto the corresponding megbrain DType; throws for
//! unsupported values.  Note both CNRT_INT8 and CNRT_QUANT8 map to
//! QuantizedS8 with a placeholder scale of 1.f.
DType cnrt_dtype_to_mgb_dtype(cnrtDataType_t data_type) {
    switch (data_type) {
        case CNRT_FLOAT16:
#if !MEGDNN_DISABLE_FLOAT16
            return dtype::Float16();
#else
            mgb_throw(MegBrainError,
                      "Float16 support is disabled at compile time.");
#endif
        case CNRT_FLOAT32:
            return dtype::Float32();
        case CNRT_INT8:
            return dtype::QuantizedS8(1.f);
        case CNRT_INT16:
            return dtype::Int16();
        case CNRT_INT32:
            return dtype::Int32();
        case CNRT_UINT8:
            return dtype::Uint8();
        //! TODO: check scale
        case CNRT_QUANT8:
            return dtype::QuantizedS8(1.f);
        default:
            mgb_throw(MegBrainError,
                      "cnrtDataType %x is not supported by MegBrain.",
                      data_type);
    }
}
//! map a megbrain DType onto the cnrt data type used in param descriptors;
//! throws for unsupported dtypes.  QuantizedS8 maps to CNRT_QUANT8; the
//! quantization scale is not carried through (see the TODO in
//! cnrt_dtype_to_mgb_dtype).
cnrtDataType_t mgb_dtype_to_cnrt_dtype(DType data_type) {
    switch (data_type.enumv()) {
#if !MEGDNN_DISABLE_FLOAT16
        case DTypeEnum::Float16:
            return CNRT_FLOAT16;
#endif
        case DTypeEnum::Float32:
            return CNRT_FLOAT32;
        case DTypeEnum::QuantizedS8:
            return CNRT_QUANT8;
        case DTypeEnum::Int32:
            return CNRT_INT32;
        default:
            mgb_throw(MegBrainError,
                      "megbrain data type %s is not supported by cnrt.",
                      data_type.name());
    }
}
}; // namespace
/* ====================== CambriconRuntimeOpr ==================== */
MGB_DYN_TYPE_OBJ_FINAL_IMPL(CambriconRuntimeOpr);
// Construct the opr: register inputs, load the serialized cnrt model from
// m_buffer, extract the function named by m_symbol, and create one graph
// output per function output.  Only valid on a Cambricon comp node.
CambriconRuntimeOpr::CambriconRuntimeOpr(SharedBuffer buf, std::string symbol,
                                         const VarNodeArray& inputs,
                                         bool tensor_dim_mutable,
                                         const OperatorNodeConfig& config)
        : Super(inputs[0]->owner_graph(), config, "cambricon_runtime", inputs),
          m_buffer{std::move(buf)},
          m_symbol{std::move(symbol)},
          m_model{nullptr},
          m_function{nullptr},
          m_context{nullptr},
          m_tensor_dim_mutable{tensor_dim_mutable} {
    mgb_assert(inputs[0]->comp_node().device_type() ==
                       CompNode::DeviceType::CAMBRICON,
               "CambriconRuntimeOpr can only be used on cambricon comp node; "
               "got %s",
               inputs[0]->comp_node().to_string().c_str());
    for (auto i : inputs) {
        add_input({i});
    }
    // load the model once; the unique_ptr deleter unloads it
    if (m_model == nullptr) {
        m_model = {new cnrtModel_t(), cnrt_intl::ModelUnloader()};
        MGB_CNRT_CHECK(cnrtLoadModelFromMem(
                m_model.get(),
                reinterpret_cast<char*>(const_cast<void*>(m_buffer.data()))));
    }
    // extract the named function from the loaded model
    if (m_function == nullptr) {
        m_function = {new cnrtFunction_t(), cnrt_intl::FunctionDeleter()};
        MGB_CNRT_CHECK(cnrtCreateFunction(m_function.get()));
        MGB_CNRT_CHECK(cnrtExtractFunction(m_function.get(), *m_model,
                                           m_symbol.c_str()));
    }
    int nr_inputs = 0;
    int nr_outputs = 0;
    // size arrays are only queried to obtain the input/output counts
    int64_t* inputs_size = nullptr;
    int64_t* outputs_size = nullptr;
    MGB_CNRT_CHECK(cnrtGetInputDataSize(&inputs_size, &nr_inputs, *m_function));
    mgb_assert(static_cast<size_t>(nr_inputs) == inputs.size(),
               "inputs size mismatch: expect=%d, got=%zu", nr_inputs,
               inputs.size());
    MGB_CNRT_CHECK(
            cnrtGetOutputDataSize(&outputs_size, &nr_outputs, *m_function));
    // single output stays unnamed; multiple outputs are named o0..oN
    if (nr_outputs == 1) {
        add_output(None);
    } else {
        for (int i = 0; i < nr_outputs; ++i) {
            add_output(ssprintf("o%d", i));
        }
    }
    // deduplicate graph nodes by the address of the model buffer
    add_equivalence_component<mgb::ScalarHash<const void*>>(m_buffer.data());
};
// Execute the extracted cnrt function: lazily create a runtime context
// bound to this device, describe every input/output (dtype + shape) in a
// param descriptor, then invoke asynchronously on the comp node's queue.
void CambriconRuntimeOpr::scn_do_execute() {
    mgb_assert(m_function != nullptr);
    auto&& cnrt_env =
            CompNodeEnv::from_comp_node(input(0)->comp_node()).cnrt_env();
    cnrt_env.activate();
    // one-time runtime context creation, bound to this device
    if (m_context == nullptr) {
        m_context = {new cnrtRuntimeContext_t(),
                     cnrt_intl::RuntimeContextDeleter()};
        MGB_CNRT_CHECK(cnrtCreateRuntimeContext(m_context.get(), *m_function,
                                                nullptr));
        int dev_id = cnrt_env.device;
        MGB_CNRT_CHECK(cnrtSetRuntimeContextDeviceId(*m_context, dev_id));
        MGB_CNRT_CHECK(cnrtInitRuntimeContext(*m_context, nullptr));
    }
    size_t nr_inputs = input().size(), nr_outputs = output().size();
    // param layout: all inputs first, then all outputs
    SmallVector<void*> params(nr_inputs + nr_outputs);
    SmallVector<cnrtParamDesc_t> param_descs(nr_inputs + nr_outputs);
    for (size_t i = 0; i < nr_inputs; ++i) {
        params[i] = input(i)->dev_tensor().raw_ptr();
        MGB_CNRT_CHECK(cnrtCreateParamDesc(&param_descs[i]));
        MGB_CNRT_CHECK(cnrtSetDataTypeToParamDesc(
                param_descs[i], mgb_dtype_to_cnrt_dtype(input(i)->dtype())));
        auto dims = mgb_shape_to_cnrt_shape(input(i)->shape());
        MGB_CNRT_CHECK(cnrtSetShapeToParamDesc(param_descs[i], dims.data(),
                                               static_cast<int>(dims.size())));
    }
    for (size_t i = 0; i < nr_outputs; ++i) {
        params[nr_inputs + i] = output(i)->dev_tensor().raw_ptr();
        MGB_CNRT_CHECK(cnrtCreateParamDesc(&param_descs[nr_inputs + i]));
        MGB_CNRT_CHECK(cnrtSetDataTypeToParamDesc(
                param_descs[nr_inputs + i],
                mgb_dtype_to_cnrt_dtype(output(i)->dtype())));
        auto dims = mgb_shape_to_cnrt_shape(output(i)->shape());
        MGB_CNRT_CHECK(cnrtSetShapeToParamDesc(param_descs[nr_inputs + i],
                                               dims.data(),
                                               static_cast<int>(dims.size())));
    }
    MGB_CNRT_CHECK(cnrtInvokeRuntimeContext_V2(*m_context, param_descs.data(),
                                               params.data(), cnrt_env.queue,
                                               nullptr));
    // NOTE(review): param descs are destroyed right after the async invoke;
    // presumably cnrt copies them during invocation -- confirm against the
    // cnrt API docs.
    for (auto& param : param_descs) {
        MGB_CNRT_CHECK(cnrtDestroyParamDesc(param));
    }
}
// Infer output shapes from input shapes.  When tensor dims are mutable the
// shapes are obtained via cnrtInferFunctionOutputShape; otherwise the input
// shapes must exactly match those baked into the cnrt function, and the
// baked output shapes are returned.
void CambriconRuntimeOpr::get_output_var_shape(
        const TensorShapeArray& inp_shape, TensorShapeArray& out_shape) const {
    mgb_assert(m_function != nullptr);
    mgb_assert(input().size() == inp_shape.size());
    if (m_tensor_dim_mutable) {
        cnrtParamDescArray_t input_descs, output_descs;
        int inp_param_num = input().size();
        int out_param_num = output().size();
        MGB_CNRT_CHECK(cnrtCreateParamDescArray(&input_descs, inp_param_num));
        MGB_CNRT_CHECK(cnrtCreateParamDescArray(&output_descs, out_param_num));
        // describe each input with its dtype and requested shape
        for (int i = 0; i < inp_param_num; ++i) {
            MGB_CNRT_CHECK(cnrtSetDataTypeToParamDesc(
                    input_descs[i],
                    mgb_dtype_to_cnrt_dtype(input(i)->dtype())));
            auto dims = mgb_shape_to_cnrt_shape(inp_shape[i]);
            MGB_CNRT_CHECK(
                    cnrtSetShapeToParamDesc(input_descs[i], dims.data(),
                                            static_cast<int>(dims.size())));
        }
        // let cnrt compute the resulting output shapes
        MGB_CNRT_CHECK(cnrtInferFunctionOutputShape(*m_function, inp_param_num,
                                                    input_descs, out_param_num,
                                                    output_descs));
        for (int i = 0; i < out_param_num; ++i) {
            int* dims = nullptr;
            int dim_num = 0;
            MGB_CNRT_CHECK(cnrtGetShapeFromParamDesc(output_descs[i], &dims,
                                                     &dim_num));
            out_shape[i] = cnrt_shape_to_mgb_shape(dims, dim_num);
        }
        MGB_CNRT_CHECK(cnrtDestroyParamDescArray(input_descs, inp_param_num));
        MGB_CNRT_CHECK(cnrtDestroyParamDescArray(output_descs, out_param_num));
    } else {
        //! check input shape match
        for (size_t i = 0; i < inp_shape.size(); ++i) {
            int* dim_values = nullptr;
            int dim_num = 0;
            MGB_CNRT_CHECK(cnrtGetInputDataShape(
                    &dim_values, &dim_num, static_cast<int>(i), *m_function));
            auto shp_in_func = cnrt_shape_to_mgb_shape(dim_values, dim_num);
            auto inpshp = inp_shape[i];
            MGB_MARK_USED_VAR(shp_in_func);
            mgb_assert(
                    inpshp.eq_shape(shp_in_func),
                    "input shape(%s) mismatch with that(%s) in cnrtFunction_t.",
                    inpshp.to_string().c_str(),
                    shp_in_func.to_string().c_str());
        }
        //! remarks: cnrt does not provide interface to let user manage
        //! workspace
        MGB_MARK_USED_VAR(mgb_dtype_to_cnrt_dtype);
        // output shapes come straight from the function
        for (size_t i = 0; i < out_shape.size(); ++i) {
            int* dim_values = nullptr;
            int dim_num = 0;
            MGB_CNRT_CHECK(cnrtGetOutputDataShape(
                    &dim_values, &dim_num, static_cast<int>(i), *m_function));
            out_shape[i] = cnrt_shape_to_mgb_shape(dim_values, dim_num);
        }
    }
}
// Require every input tensor to be laid out contiguously before execution.
void CambriconRuntimeOpr::add_input_layout_constraint() {
    //! default contiguous
    for (auto i : input()) {
        i->add_layout_constraint_contiguous();
    }
}
// Validate input dtypes against the cnrt function and assign output dtypes
// from it.  Quantized (QuantizedS8) outputs must have their dtype -- and
// hence scale -- specified by the user beforehand, since cnrt does not
// carry the scale.
void CambriconRuntimeOpr::init_output_dtype() {
    cnrtDataType_t* inp_dtype_array = nullptr;
    int inp_num;
    MGB_CNRT_CHECK(
            cnrtGetInputDataType(&inp_dtype_array, &inp_num, *m_function));
    for (size_t i = 0; i < input().size(); ++i) {
        auto dt_cnrt = cnrt_dtype_to_mgb_dtype(inp_dtype_array[i]);
        auto dt_inp = input(i)->dtype();
        MGB_MARK_USED_VAR(dt_cnrt);
        MGB_MARK_USED_VAR(dt_inp);
        mgb_assert(dt_cnrt.valid() && dt_inp.valid() &&
                           dt_cnrt.enumv() == dt_inp.enumv(),
                   "Input %zu's data type mismatch with that in "
                   "cnrtFunction_t: expected %s, got %s",
                   i, dt_cnrt.name(), dt_inp.name());
    }
    cnrtDataType_t* out_dtype_array = nullptr;
    int out_num;
    MGB_CNRT_CHECK(
            cnrtGetOutputDataType(&out_dtype_array, &out_num, *m_function));
    for (size_t i = 0; i < output().size(); ++i) {
        auto dt_cnrt = cnrt_dtype_to_mgb_dtype(out_dtype_array[i]);
        mgb_assert(dt_cnrt.valid(),
                   "output dtype checking failed: invalid dtype returned.");
        if (dt_cnrt.enumv() == DTypeEnum::QuantizedS8) {
            mgb_assert(output(i)->dtype().valid(),
                       "user should specify scale of output tensor of "
                       "CambriconRuntimeOpr.");
        }
        // only fill in the dtype when the user has not already set one
        if (!output(i)->dtype().valid())
            output(i)->dtype(dt_cnrt);
    }
}
// Insert a CambriconRuntimeOpr built from an already-owned buffer into the
// graph of src[0] and return symbol vars for all of its outputs.
SymbolVarArray CambriconRuntimeOpr::make(SharedBuffer buf, std::string symbol,
                                         const SymbolVarArray& src,
                                         bool tensor_dim_mutable,
                                         const OperatorNodeConfig& config) {
    VarNodeArray var_node_array = cg::to_var_node_array(src);
    auto cambricon_runtime_opr = std::make_unique<CambriconRuntimeOpr>(
            std::move(buf), std::move(symbol), var_node_array,
            tensor_dim_mutable, config);
    auto ret = cg::to_symbol_var_array(
            src[0].node()
                    ->owner_graph()
                    ->insert_opr(std::move(cambricon_runtime_opr))
                    ->output());
    return ret;
}
// Convenience overload: copy `buf` into an internally owned SharedBuffer
// before delegating to the primary make().  Throws SystemError when no
// Cambricon device is available.
SymbolVarArray CambriconRuntimeOpr::make(const void* buf, size_t size,
                                         std::string symbol,
                                         const SymbolVarArray& src,
                                         bool tensor_dim_mutable,
                                         const OperatorNodeConfig& config) {
    mgb_throw_if(!CompNode::get_device_count(CompNode::DeviceType::CAMBRICON),
                 SystemError,
                 "can not create CambriconRuntimeOpr when Cambricon is not "
                 "available");
    // take a private copy of the caller's data; released via delete[]
    std::shared_ptr<uint8_t> shptr{new uint8_t[size],
                                   [](uint8_t* p) { delete[] p; }};
    memcpy(shptr.get(), buf, size);
    SharedBuffer buffer{std::move(shptr), size};
    return make(std::move(buffer), std::move(symbol), src, tensor_dim_mutable,
                config);
}
#endif // MGB_CAMBRICON
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
# Raw-opr declaration for `cambricon_runtime`.  The strings in `body` are
# executed by the opr-declaration framework with data_bytes, symbol, inputs,
# tensor_dim_mutable and config bound in scope.
decl_raw_opr(
    'cambricon_runtime',
    desc='create an operator that could load and run cnrt offline models',
    inputs=[
        Doc('data_bytes', 'serialized cnrt/cnml model'),
        Doc('symbol', 'name of cnrt/cnml function', 'str'),
        Doc('inputs', 'input vars', 'list of :class:`.SymbolVar`'),
        Doc('tensor_dim_mutable', 'whether tensor shape is mutable in cnrt/cnml model', 'bool'),
    ],
    body=[
        'assert isinstance(data_bytes, bytes), '
        '"data must be bytes; got {}".format(type(data_bytes))',
        'output = _mgb._Opr.cambricon_runtime(data_bytes, symbol, inputs, tensor_dim_mutable, config)',
        # the opr returns an array of vars; do not unwrap single results
        'cvt_result_kwargs["explode_single"] = False',
    ],
)
/**
* \file src/cambricon/impl/cambricon_runtime_opr.sereg.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#include "megbrain/cambricon/cambricon_runtime_opr.h"
#include "megbrain/serialization/sereg.h"
namespace mgb {
namespace serialization {
//! (de)serialization for CambriconRuntimeOpr: the dump format is three
//! length-prefixed buffers -- model bytes, symbol name, and the
//! tensor_dim_mutable flag -- in that order.
template <>
struct OprLoadDumpImpl<opr::CambriconRuntimeOpr, 0> {
    static void dump(OprDumpContext& ctx, const cg::OperatorNodeBase& opr_) {
        auto&& opr = opr_.cast_final_safe<opr::CambriconRuntimeOpr>();
        auto&& buf = opr.buffer();
        ctx.dump_buf_with_len(buf.data(), buf.size());
        auto&& symbol = opr.symbol();
        ctx.dump_buf_with_len(symbol.data(), symbol.size());
        bool tensor_dim_mutable = opr.is_tensor_dim_mutable();
        ctx.dump_buf_with_len(&tensor_dim_mutable, sizeof(bool));
    }
    static cg::OperatorNodeBase* load(OprLoadContext& ctx,
                                      const cg::VarNodeArray& inputs,
                                      const OperatorNodeConfig& config) {
        // the target comp node must be active before the opr touches cnrt
        inputs.at(0)->comp_node().activate();
        auto buf = ctx.load_shared_buf_with_len();
        auto symbol = ctx.load_buf_with_len();
        auto tensor_dim_mutable_storage = ctx.load_buf_with_len();
        bool tensor_dim_mutable;
        memcpy(&tensor_dim_mutable, tensor_dim_mutable_storage.data(),
               sizeof(bool));
        return opr::CambriconRuntimeOpr::make(std::move(buf), std::move(symbol),
                                              cg::to_symbol_var_array(inputs),
                                              tensor_dim_mutable, config)
                .at(0)
                .node()
                ->owner_opr();
    }
};
} // namespace serialization
namespace opr {
//! shallow copy: rebuild the opr from its existing buffer/symbol/flag on
//! the given replacement inputs
cg::OperatorNodeBase* opr_shallow_copy_cambricon_runtime_opr(
        const serialization::OprShallowCopyContext& ctx,
        const cg::OperatorNodeBase& opr_, const VarNodeArray& inputs,
        const OperatorNodeConfig& config) {
    auto&& opr = opr_.cast_final_safe<CambriconRuntimeOpr>();
    return CambriconRuntimeOpr::make(opr.buffer(), opr.symbol(),
                                     cg::to_symbol_var_array(inputs),
                                     opr.is_tensor_dim_mutable(), config)
            .at(0)
            .node()
            ->owner_opr();
}
MGB_SEREG_OPR(CambriconRuntimeOpr, 0);
MGB_REG_OPR_SHALLOW_COPY(CambriconRuntimeOpr,
opr_shallow_copy_cambricon_runtime_opr);
} // namespace opr
} // namespace mgb
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
/**
* \file src/cambricon/include/megbrain/cambricon/cambricon_runtime_opr.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#pragma once
#include "megbrain/comp_node_env.h"
#include "megbrain/graph.h"
#include "megbrain/serialization/file.h"
#if MGB_CAMBRICON
namespace mgb {
namespace opr {
namespace cnrt_intl {
//! unique_ptr deleter for heap-allocated cnrtModel_t handles: unloads the
//! model, then frees the handle object itself.
//! Fix: the original only called cnrtUnloadModel and never deleted the
//! pointer allocated with `new cnrtModel_t()`, leaking the handle object.
struct ModelUnloader {
    void operator()(cnrtModel_t* model) {
        if (model != nullptr) {
            MGB_CNRT_CHECK(cnrtUnloadModel(*model));
            delete model;
        }
    }
};
//! unique_ptr deleter for heap-allocated cnrtFunction_t handles: destroys
//! the cnrt function, then frees the handle object itself.
//! Fix: the original leaked the pointer allocated with
//! `new cnrtFunction_t()` by never deleting it.
struct FunctionDeleter {
    void operator()(cnrtFunction_t* function) {
        if (function != nullptr) {
            MGB_CNRT_CHECK(cnrtDestroyFunction(*function));
            delete function;
        }
    }
};
//! unique_ptr deleter for heap-allocated cnrtRuntimeContext_t handles:
//! destroys the runtime context, then frees the handle object itself.
//! Fix: the original leaked the pointer allocated with
//! `new cnrtRuntimeContext_t()` by never deleting it.
struct RuntimeContextDeleter {
    void operator()(cnrtRuntimeContext_t* context) {
        if (context != nullptr) {
            MGB_CNRT_CHECK(cnrtDestroyRuntimeContext(*context));
            delete context;
        }
    }
};
//! owning smart-pointer aliases over heap-allocated cnrt handles; the
//! custom deleters release the underlying cnrt objects
using CnrtModelUniquePtr = std::unique_ptr<cnrtModel_t, ModelUnloader>;
using CnrtFunctionUniquePtr = std::unique_ptr<cnrtFunction_t, FunctionDeleter>;
using CnrtRuntimeContextUniquePtr =
        std::unique_ptr<cnrtRuntimeContext_t, RuntimeContextDeleter>;
}; // namespace cnrt_intl
//! Graph operator that embeds a serialized Cambricon (cnrt) model and
//! executes one named function from it during graph execution.
MGB_DEFINE_OPR_CLASS(CambriconRuntimeOpr, cg::SingleCNOutshapePureByInshapeOprBase) // {
public:
    using CnrtModelUniquePtr = cnrt_intl::CnrtModelUniquePtr;
    using CnrtFunctionUniquePtr = cnrt_intl::CnrtFunctionUniquePtr;
    using CnrtRuntimeContextUniquePtr = cnrt_intl::CnrtRuntimeContextUniquePtr;
    using SharedBuffer = mgb::serialization::SharedBuffer;
    void scn_do_execute() override;
    void get_output_var_shape(const TensorShapeArray& inp_shape,
                              TensorShapeArray& out_shape) const override;
    void add_input_layout_constraint() override;
    void init_output_dtype() override;
    CambriconRuntimeOpr(SharedBuffer buf, std::string symbol,
                        const VarNodeArray& inputs, bool tensor_dim_mutable,
                        const OperatorNodeConfig& config);
    //! serialized model bytes
    const SharedBuffer& buffer() const {
        return m_buffer;
    }
    //! name of the function extracted from the model
    const std::string& symbol() const {
        return m_symbol;
    }
    //! whether input tensor shapes may vary between executions
    bool is_tensor_dim_mutable() const {
        return m_tensor_dim_mutable;
    }
    //! insert the opr into the graph of src[0] from an owned buffer
    static SymbolVarArray make(SharedBuffer buf, std::string symbol,
                               const SymbolVarArray& src,
                               bool tensor_dim_mutable = false,
                               const OperatorNodeConfig& config = {});
    //! overload that copies a raw buffer first
    static SymbolVarArray make(const void* buf, size_t size, std::string symbol,
                               const SymbolVarArray& src,
                               bool tensor_dim_mutable = false,
                               const OperatorNodeConfig& config = {});
private:
    SharedBuffer m_buffer;
    std::string m_symbol;
    CnrtModelUniquePtr m_model;
    CnrtFunctionUniquePtr m_function;
    CnrtRuntimeContextUniquePtr m_context;
    bool m_tensor_dim_mutable;
};
} // namespace opr
} // namespace mgb
#endif // MGB_CAMBRICON
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
此差异已折叠。
/**
* \file src/core/impl/comp_node/atlas/comp_node.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "./comp_node.h"
#include "megbrain/comp_node_env.h"
#include <memory>
#include <string>
using namespace mgb;
#if MGB_ATLAS
#include "megbrain/common.h"
#include "megbrain/comp_node/alloc.h"
#include "megbrain/utils//timer.h"
#include "megcore_atlas.h"
#include <cctype>
#include <cstdio>
#include <acl/acl.h>
#include <limits>
using AtlasCompNodeImpl = AtlasCompNode::CompNodeImpl;
/* ===================== AtlasCompNodeImpl ===================== */
//! Comp-node implementation for Atlas devices, backed by the ACL runtime
//! (aclrt*).  Instances live in a process-wide StaticData pool.
class AtlasCompNode::CompNodeImpl final : public CompNode::Impl {
    MGB_DYN_TYPE_OBJ_FINAL_DECL;
    friend class EventImpl;
    friend class AtlasCompNode;
    struct DeviceInfo;
    struct StaticData;
    static StaticData* sd;
    static Spinlock sd_mtx;
    //! set to true when m_locator is assigned; set to false if async init
    //! failed
    bool m_initialized = false;
    Locator m_locator, m_locator_logical;
    DeviceInfo* m_device_info;
    std::unique_ptr<Event> m_sync_event;
    Spinlock m_sync_event_mtx;
    //! make this node's ACL environment current on the calling thread
    void activate() { m_env.atlas_env().activate(); }
    void init(const Locator& locator, const Locator& locator_logical);
    void fini();
    //! return whether global finalized, and print warning in such case
    static inline bool check_global_finalized();
    //! enable peer copy from dev0 to dev1
    static void enable_peer_access(int dev0, int dev1);
    static void static_free_device(ImplBase* self, void* ptr) {
        static_cast<CompNodeImpl*>(self)->free_device(ptr);
    }
    static void static_free_host(ImplBase* self, void* ptr) {
        static_cast<CompNodeImpl*>(self)->free_host(ptr);
    }
public:
    CompNodeImpl() : Impl(static_free_device, static_free_host) {}
    //! allocate device memory with the huge-page-first policy
    void* alloc_device(size_t size) override {
        activate();
        void* addr;
        MGB_ATLAS_CHECK(aclrtMalloc(&addr, size, ACL_MEM_MALLOC_HUGE_FIRST));
        return addr;
    }
    void free_device(void* ptr) {
        // silently ignore frees that happen after global finalization
        if (check_global_finalized())
            return;
        activate();
        MGB_ATLAS_CHECK(aclrtFree(ptr));
    }
    void* alloc_host(size_t size) override {
        void* ptr;
        MGB_ATLAS_CHECK(aclrtMallocHost(&ptr, size));
        return ptr;
    }
    void free_host(void* ptr) { MGB_ATLAS_CHECK(aclrtFreeHost(ptr)); }
    //! device->host copy is asynchronous on this node's stream
    void copy_to_host(void* host_ptr, const void* device_ptr,
                      size_t size) override {
        activate();
        MGB_ATLAS_CHECK(aclrtMemcpyAsync(host_ptr, size, device_ptr, size,
                                         ACL_MEMCPY_DEVICE_TO_HOST,
                                         m_env.atlas_env().stream));
    }
    //! host->device copy is synchronous -- NOTE(review): asymmetric with
    //! copy_to_host above; confirm this is intentional
    void copy_to_device(void* device_ptr, const void* host_ptr,
                        size_t size) override {
        activate();
        MGB_ATLAS_CHECK(aclrtMemcpy(device_ptr, size, host_ptr, size,
                                    ACL_MEMCPY_HOST_TO_DEVICE));
    }
    void peer_copy_to(Impl* dest_impl, void* dest, const void* src,
                      size_t size) override;
    size_t get_mem_addr_alignment() override {
        return m_env.property().mem_alignment;
    }
    std::unique_ptr<Event> create_event(size_t flags) override;
    void sync() override;
    MemNode mem_node() override;
    size_t get_mem_padding() override { return 32; }
    //! free/total memory is reported as "unlimited" (not queried from ACL)
    std::pair<size_t, size_t> get_mem_status_bytes() override {
        return {std::numeric_limits<size_t>::max(),
                std::numeric_limits<size_t>::max()};
    }
    Locator locator() override { return m_locator; }
    Locator locator_logical() override { return m_locator_logical; }
};
MGB_DYN_TYPE_OBJ_FINAL_IMPL(AtlasCompNode::CompNodeImpl);
//! per-physical-device bookkeeping for atlas comp nodes
struct AtlasCompNodeImpl::DeviceInfo {
int dev_num = -1; //!< device ordinal; -1 means not yet initialized
//! record the device ordinal after activating its env
void init(const CompNodeEnv& env) {
auto&& atlas_env = env.atlas_env();
atlas_env.activate();
dev_num = atlas_env.device;
}
//! reset the device on teardown
void fini() {
MGB_ATLAS_CHECK(aclrtResetDevice(dev_num));
}
};
//! global singleton holding all comp node impls and device records;
//! placement-new'd into static storage by load_atlas() so it stays
//! addressable even after global finalize (see AtlasCompNode::finalize)
struct AtlasCompNodeImpl::StaticData {
static constexpr int MAX_NR_COMP_NODE = 1024, MAX_NR_DEVICE = 64;
std::recursive_mutex mtx; //!< guards node[] / dev_info[] bookkeeping
AtlasCompNode::CompNodeImpl node[MAX_NR_COMP_NODE];
DeviceInfo dev_info[MAX_NR_DEVICE];
int nr_node = 0, //!< number of loaded node[]
nr_dev_used = 0; //!< number of used dev_info[]
StaticData() {}
// finalize nodes first, then devices, so node resources are released
// before the devices backing them are reset
~StaticData() {
for (int i = 0; i < nr_node; ++i)
node[i].fini();
for (int i = 0; i < nr_dev_used; ++i)
dev_info[i].fini();
}
};
AtlasCompNodeImpl::StaticData* AtlasCompNodeImpl::sd = nullptr;
Spinlock AtlasCompNodeImpl::sd_mtx;
/*!
 * \brief bind this impl to \p locator and set up its AtlasEnv
 *
 * Also registers the physical device in sd->dev_info the first time it
 * is used. Called with sd->mtx held (see load_atlas()).
 */
void AtlasCompNodeImpl::init(const Locator& locator,
const Locator& locator_logical) {
m_locator = locator;
m_locator_logical = locator_logical;
m_initialized = true;
CompNodeEnv::AtlasEnv atlas_env;
atlas_env.device = locator.device;
m_env.init_atlas(make_comp_node_from_impl(this), atlas_env);
// look up an existing record for this physical device
DeviceInfo* dev_info = nullptr;
for (int i = 0; i < sd->nr_dev_used; ++i) {
if (sd->dev_info[i].dev_num == locator.device) {
dev_info = &sd->dev_info[i];
break;
}
}
if (!dev_info) {
dev_info = &sd->dev_info[sd->nr_dev_used];
dev_info->init(m_env);
// note: add nr_dev_used only after init succeeds
++sd->nr_dev_used;
}
m_device_info = dev_info;
}
//! release per-node resources; safe to call on an uninitialized node
void AtlasCompNodeImpl::fini() {
if (!m_initialized)
return;
m_sync_event.reset();
m_env.fini();
m_initialized = false;
m_device_info = nullptr;
}
/*!
 * \brief copy \p size bytes from this comp node to \p dest_impl
 *
 * atlas -> atlas on the same device: async D2D copy on the dest stream;
 * cross-device atlas copies are unsupported and throw. atlas -> cpu:
 * the copy is dispatched on the CPU comp node queue and synchronizes
 * the atlas stream so the host buffer is valid for later CPU tasks.
 */
void AtlasCompNodeImpl::peer_copy_to(Impl* dest_impl, void* dest,
                                     const void* src, size_t size) {
    if (dest_impl->same_type<AtlasCompNodeImpl>()) {
        auto&& dst_env =
                static_cast<AtlasCompNodeImpl*>(dest_impl)->m_env.atlas_env();
        auto&& src_env = m_env.atlas_env();
        activate();
        if (dst_env.device == src_env.device) {
            MGB_ATLAS_CHECK(aclrtMemcpyAsync(dest, size, src, size,
                                             ACL_MEMCPY_DEVICE_TO_DEVICE,
                                             dst_env.stream));
        } else {
            mgb_throw(MegBrainError,
                      "Atlas does not support peer copy between different "
                      "devices.");
        }
        return;
    }
    // fix: message previously said "cuda peer_copy_to" in atlas code
    mgb_assert(dest_impl->env().property().type == DeviceType::CPU,
               "atlas peer_copy_to only implemented for CPU");
    auto copy = [this, dest, src, size]() {
        auto stream = m_env.atlas_env().stream;
        m_env.atlas_env().activate();
        MGB_ATLAS_CHECK(aclrtMemcpyAsync(dest, size, src, size,
                                         ACL_MEMCPY_DEVICE_TO_HOST, stream));
        // block the CPU queue until the D2H transfer has landed
        MGB_ATLAS_CHECK(aclrtSynchronizeStream(stream));
    };
    dest_impl->env().cpu_env().dispatch(copy);
}
//! identity token of the memory space this node allocates from
MemNode AtlasCompNodeImpl::mem_node() {
// m_device_info would be null before async init finishes; so we just return
// a private pointer related to device number here
return MemNode{sd->dev_info + m_locator.device};
}
//! block the host until all work queued on this comp node has finished,
//! implemented as record + host_wait on a lazily-created shared event
void AtlasCompNodeImpl::sync() {
activate();
Event* event;
{
// create the cached sync event under lock; wait outside the lock
MGB_LOCK_GUARD(m_sync_event_mtx);
if (!m_sync_event)
m_sync_event = create_event(0);
event = m_sync_event.get();
}
event->record();
event->host_wait();
}
/*!
 * \brief enable peer copy from dev0 to dev1
 *
 * Atlas has no peer-to-peer copy support, so this unconditionally
 * throws; the parameters exist only to satisfy the interface.
 */
void AtlasCompNodeImpl::enable_peer_access(int dev0, int dev1) {
    MGB_MARK_USED_VAR(dev0);
    MGB_MARK_USED_VAR(dev1);
    // fix: "differents device" -> "different devices"
    mgb_throw(MegBrainError,
              "Atlas does not support peer copy between different "
              "devices.");
}
/*!
 * \brief whether the global StaticData has been destroyed
 *
 * Used by free_device() so releases arriving after CompNode::finalize()
 * do not touch the torn-down runtime; logs a debug warning once.
 */
bool AtlasCompNodeImpl::check_global_finalized() {
if (!sd) {
// test_and_set ensures the warning is emitted at most once
static std::atomic_flag warn_printed = ATOMIC_FLAG_INIT;
if (!warn_printed.test_and_set()) {
mgb_log_debug(
"atlas comp node method called after global finalize");
}
return true;
}
return false;
}
/* ===================== EventImpl ===================== */
/**
* \warning Current we just use cpu timer to do record, later when the api of
* ddk is ready, we change to normal event.
*/
//! event implementation backed by aclrtEvent
class AtlasCompNode::EventImpl final : public EventImplHelper {
    AtlasCompNodeImpl* const m_comp_node_impl;
    aclrtEvent m_atlas_event;
    bool m_init_finished = false;  //!< only destroy the event if created

    //! record the event on the comp node's stream
    void do_record() override {
        m_comp_node_impl->activate();
        auto&& env = m_comp_node_impl->m_env.atlas_env();
        MGB_ATLAS_CHECK(aclrtRecordEvent(m_atlas_event, env.stream));
    }

    //! non-blocking completion query
    bool do_finished() override {
        m_comp_node_impl->activate();
        aclrtEventStatus status;
        MGB_ATLAS_CHECK(aclrtQueryEvent(m_atlas_event, &status));
        if (status == ACL_EVENT_STATUS_COMPLETE)
            return true;
        if (status == ACL_EVENT_STATUS_NOT_READY)
            return false;
        mgb_throw(AtlasError, "invalid event status: %d", int(status));
    }

    //! block the host until the event completes
    void host_wait_cv() override {
        MGB_ATLAS_CHECK(aclrtSynchronizeEvent(m_atlas_event));
    }

    //! elapsed time from this event to \p end, in seconds
    double do_elapsed_time_until(EventImplHelper& end) override {
        m_comp_node_impl->activate();
        float ret = 0.0;
        MGB_ATLAS_CHECK(aclrtEventElapsedTime(
                &ret, m_atlas_event,
                static_cast<EventImpl&>(end).m_atlas_event));
        // convert ms -> s (per aclrtEventElapsedTime contract)
        return static_cast<double>(ret) * 1e-3;
    }

    void do_device_wait_by(Impl* cn_impl) override;

public:
    EventImpl(AtlasCompNodeImpl* comp_node_impl, size_t create_flags)
            : EventImplHelper(comp_node_impl, create_flags),
              m_comp_node_impl{comp_node_impl} {
        m_comp_node_impl->activate();
        MGB_ATLAS_CHECK(aclrtCreateEvent(&m_atlas_event));
        m_init_finished = true;
    }

    ~EventImpl() {
        if (m_init_finished) {
            MGB_TRY { MGB_ATLAS_CHECK(aclrtDestroyEvent(m_atlas_event)); }
            MGB_CATCH(MegBrainError & exc, {
                // fix: message previously said "cuda event"
                mgb_log_error("failed to destroy atlas event: %s",
                              exc.what());
            })
        }
    }
};
//! factory for atlas events
std::unique_ptr<CompNode::Event> AtlasCompNodeImpl::create_event(size_t flags) {
return std::make_unique<EventImpl>(this, flags);
}
/*!
 * \brief make \p cn_impl wait for this event
 *
 * atlas waiter: queue a stream-level wait; cpu waiter: block the cpu
 * dispatch queue on event synchronization; other device types throw.
 */
void AtlasCompNode::EventImpl::do_device_wait_by(Impl* cn_impl) {
if (cn_impl->dyn_typeinfo() == AtlasCompNodeImpl::typeinfo()) {
auto imp = static_cast<AtlasCompNodeImpl*>(cn_impl);
auto stream = imp->m_env.atlas_env().stream;
imp->activate();
MGB_ATLAS_CHECK(aclrtStreamWaitEvent(stream, m_atlas_event));
return;
}
if (cn_impl->env().property().type == DeviceType::CPU) {
auto waiter = [this]() {
MGB_ATLAS_CHECK(aclrtSynchronizeEvent(m_atlas_event));
};
cn_impl->add_callback(std::move(waiter));
return;
}
mgb_throw(MegBrainError, "unimplemented event device_wait_by config");
}
/* ===================== AtlasCompNode static methods ===================== */
//! atlas support is compiled in; actual device presence is reported by
//! get_device_count()
bool AtlasCompNode::available() {
return true;
}
//! destroy the global StaticData after syncing all devices; the object
//! was placement-new'd into static storage (see load_atlas), hence the
//! explicit destructor call instead of delete
void AtlasCompNode::finalize() {
if (AtlasCompNodeImpl::sd) {
sync_all();
auto ptr = AtlasCompNodeImpl::sd;
AtlasCompNodeImpl::sd = nullptr;
ptr->~StaticData();
}
}
/*!
 * \brief get or create the comp node impl for the given locators
 *
 * Reuses an already-initialized node with the same logical locator;
 * otherwise takes the first uninitialized slot (or a fresh one) and
 * initializes it under sd.mtx.
 */
CompNode::Impl* AtlasCompNode::load_atlas(const Locator& locator,
const Locator& locator_logical) {
auto&& sdptr = AtlasCompNodeImpl::sd;
{
MGB_LOCK_GUARD(AtlasCompNodeImpl::sd_mtx);
if (!sdptr) {
// use static storage so object can be safely accessed even after
// global finalize
using T = AtlasCompNodeImpl::StaticData;
static std::aligned_storage_t<sizeof(T), alignof(T)> storage;
sdptr = new (&storage) T;
}
}
auto&& sd = *sdptr;
MGB_LOCK_GUARD(sd.mtx);
CompNodeImpl* available_node = nullptr;
for (int i = 0; i < sd.nr_node; ++i) {
auto&& cur = sd.node[i];
if (cur.m_initialized) {
if (cur.m_locator_logical == locator_logical) {
return &cur;
}
} else {
// remember a free slot in case no match is found
available_node = &cur;
}
}
if (!available_node) {
mgb_assert(sd.nr_node < sd.MAX_NR_COMP_NODE,
"too many CompNode allocated");
mgb_assert(locator.device < sd.MAX_NR_COMP_NODE,
"device number too large");
available_node = &sd.node[sd.nr_node++];
}
mgb_assert(!available_node->m_initialized);
available_node->init(locator, locator_logical);
log_comp_node_created(locator, locator_logical);
return available_node;
}
//! wait for all loaded atlas comp nodes: first force async env init to
//! finish for every node, then synchronize the whole device
void AtlasCompNode::sync_all() {
auto sd = AtlasCompNodeImpl::sd;
if (!sd)
return;
for (int i = 0;; ++i) {
// ensure async init finished
CompNodeEnv* env;
{
// hold the lock only to read bookkeeping, not while waiting
MGB_LOCK_GUARD(sd->mtx);
if (i >= sd->nr_node) {
break;
}
env = &sd->node[i].env();
}
env->atlas_env();
}
MGB_LOCK_GUARD(sd->mtx);
MGB_ATLAS_CHECK(aclrtSynchronizeDevice());
}
//! invoke \p callback for every loaded atlas comp node; the callback
//! runs outside the lock so it may itself use comp node APIs
void AtlasCompNode::foreach (thin_function<void(CompNode)> callback) {
auto sd = AtlasCompNodeImpl::sd;
if (!sd)
return;
for (int i = 0;; ++i) {
CompNode cur;
{
MGB_LOCK_GUARD(sd->mtx);
if (i >= sd->nr_node)
return;
cur = make_comp_node_from_impl(&sd->node[i]);
}
callback(cur);
}
}
/*!
 * \brief number of atlas devices, cached after the first successful query
 *
 * A failed query logs an error and leaves the cache at 0 so a later
 * call retries.
 */
size_t AtlasCompNode::get_device_count() {
    static uint32_t cnt = 0;
    static Spinlock mtx;
    MGB_LOCK_GUARD(mtx);
    if (cnt == 0) {
        uint32_t dev_cnt = 0;
        auto ret = aclrtGetDeviceCount(&dev_cnt);
        if (ret != ACL_ERROR_NONE) {
            // fix: log message read "aclrtGetDeviceCountfaild"; also drop
            // the dead `cnt = 0` that was immediately overwritten below
            mgb_log_error("aclrtGetDeviceCount failed: %s (err %d)",
                          ::megcore::atlas::get_error_str(ret),
                          static_cast<int>(ret));
        }
        cnt = dev_cnt;
    }
    return cnt;
}
#else
// stub implementations used when atlas support is disabled at compile
// time (MGB_ATLAS == 0): report unavailable / zero devices and throw on
// any attempt to actually load an atlas comp node
bool AtlasCompNode::available() {
return false;
}
void AtlasCompNode::foreach (thin_function<void(CompNode)>) {}
void AtlasCompNode::finalize() {}
size_t AtlasCompNode::get_device_count() {
return 0;
}
AtlasCompNode::Impl* AtlasCompNode::load_atlas(const Locator&, const Locator&) {
mgb_throw(MegBrainError, "atlas disabled at compile time");
}
void AtlasCompNode::sync_all() {}
#endif // MGB_ATLAS
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
/**
* \file src/core/impl/comp_node/atlas/comp_node.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#pragma once
#include <mutex>
#include "../impl_helper.h"
namespace mgb {
//! comp node helper for Huawei Atlas devices; implementation lives in
//! the corresponding .cpp file
class AtlasCompNode final : public CompNodeImplHelper {
public:
static constexpr Flag sm_flag = Flag::QUEUE_LIMITED | Flag::HAS_COPY_STREAM;
class CompNodeImpl;
class EventImpl;
//! whether atlas comp node is available
static bool available();
static void foreach (thin_function<void(CompNode)> callback);
static void finalize();
static size_t get_device_count();
static Impl* load_atlas(const Locator& locator,
const Locator& locator_logical);
static void sync_all();
};
} // namespace mgb
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
此差异已折叠。
/**
* \file src/core/impl/comp_node/cambricon/comp_node.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#pragma once
#include "../impl_helper.h"
namespace mgb {
//! comp node helper for Cambricon devices; implementation lives in the
//! corresponding .cpp file
class CambriconCompNode final: public CompNodeImplHelper {
public:
static constexpr Flag sm_flag =
Flag::QUEUE_LIMITED | Flag::HAS_COPY_STREAM;
class CompNodeImpl;
class EventImpl;
//! whether cambricon comp node is available
static bool available();
static void try_coalesce_all_free_memory();
static void foreach(thin_function<void(CompNode)> callback);
static void finalize();
static size_t get_device_count();
static Impl* load_cambricon(
const Locator &locator, const Locator &locator_logical);
static void sync_all();
};
}
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
......@@ -15,6 +15,8 @@
#include "./cuda/comp_node.h"
#include "./cpu/comp_node.h"
#include "./cambricon/comp_node.h"
#include "./atlas/comp_node.h"
#include <cstring>
#include <atomic>
......@@ -40,6 +42,10 @@ namespace {
return "gpu";
case DT::CPU:
return "cpu";
case DT::ATLAS:
return "atlas";
case DT::CAMBRICON:
return "cambricon";
case DT::MULTITHREAD:
return "multithread";
default:
......@@ -145,7 +151,20 @@ CompNode::Locator CompNode::Locator::parse(const std::string &id) {
DeviceType dev_type;
// parse dev_type
if (ptr[0] == 'm') {
if (ptr[0] == 'a') {
if (strncmp(ptr, "atlas", 5)) {
err();
}
dev_type = DeviceType::ATLAS;
ptr += 5;
}
else if (ptr[2] == 'm') {
if (strncmp(ptr, "cambricon", 9)) {
err();
}
dev_type = DeviceType::CAMBRICON;
ptr += 9;
} else if (ptr[0] == 'm') {
if (strncmp(ptr, "multithread", 11)) {
err();
}
......@@ -478,6 +497,13 @@ CompNode CompNode::load(const Locator& locator_physical,
case DeviceType::CPU:
ret = CpuCompNode::load_cpu(locator_physical, locator_logical);
break;
case DeviceType::ATLAS:
ret = AtlasCompNode::load_atlas(locator_physical, locator_logical);
break;
case DeviceType::CAMBRICON:
ret = CambriconCompNode::load_cambricon(locator_physical,
locator_logical);
break;
default:
mgb_throw(MegBrainError, "bad device type");
}
......@@ -496,20 +522,27 @@ void CompNode::finalize() {
comp_node_detail::DepedentObjList::invoke_callback_and_clean();
CudaCompNode::finalize();
CpuCompNode::finalize();
CambriconCompNode::finalize();
AtlasCompNode::finalize();
}
void CompNode::try_coalesce_all_free_memory() {
CudaCompNode::try_coalesce_all_free_memory();
CambriconCompNode::try_coalesce_all_free_memory();
}
void CompNode::sync_all() {
CudaCompNode::sync_all();
CpuCompNode::sync_all();
CambriconCompNode::sync_all();
AtlasCompNode::sync_all();
}
void CompNode::foreach(thin_function<void(CompNode)> callback) {
CudaCompNode::foreach(callback);
CpuCompNode::foreach(callback);
CambriconCompNode::foreach(callback);
AtlasCompNode::foreach(callback);
}
size_t CompNode::get_device_count(DeviceType type, bool warn) {
......@@ -519,6 +552,10 @@ size_t CompNode::get_device_count(DeviceType type, bool warn) {
case DeviceType::MULTITHREAD:
case DeviceType::CPU:
return CpuCompNode::get_device_count();
case DeviceType::CAMBRICON:
return CambriconCompNode::get_device_count();
case DeviceType::ATLAS:
return AtlasCompNode::get_device_count();
default:
mgb_throw(MegBrainError, "bad device type");
}
......@@ -534,6 +571,12 @@ bool CompNode::contain_flag(DeviceType device_type, Flag flag) {
case DeviceType::CPU:
cn_flag = CpuCompNode::sm_flag;
break;
case DeviceType::CAMBRICON:
cn_flag = CambriconCompNode::sm_flag;
break;
case DeviceType::ATLAS:
cn_flag = AtlasCompNode::sm_flag;
break;
default:
mgb_throw(MegBrainError, "unexpected device type");
}
......
......@@ -528,9 +528,23 @@ class CpuCompNode::CompNodeImpl final: public CpuDispatchableBase {
Impl *dest_impl, void *dest,
const void *src, size_t size) override {
if (!dest_impl->same_type<CpuCompNode::CompNodeImpl>()) {
if (dest_impl->env().property().type == DeviceType::ATLAS) {
#if MGB_ATLAS
dest_impl->copy_to_device(dest, src, size);
return;
#else
mgb_throw(MegBrainError,
"Atlas comp_node used but "
"MGB_ATLAS not enabled");
#endif
} else {
mgb_assert(locator().device == Locator::DEVICE_CPU_DEFAULT,
"currently only peer copy from default cpu comp nodes "
"is implemented");
"currently only peer copy from default cpu comp "
"nodes "
"is implemented");
}
}
dest_impl->copy_to_device(dest, src, size);
}
......@@ -841,12 +855,22 @@ void CpuCompNode::CpuDispatchableBase::EventImpl::do_device_wait_by(
auto type = cn_impl->env().property().type;
mgb_throw_if(type != CompNode::DeviceType::CPU
&& type != CompNode::DeviceType::CUDA
&& type != CompNode::DeviceType::ATLAS
,
MegBrainError,
"currently CPU can only wait for CPU, CUDA"
"currently CPU can only wait for CPU, CUDA, ATLAS"
);
}
if (cn_impl->env().property().type == CompNode::DeviceType::ATLAS) {
#if MGB_ATLAS
return m_comp_node_impl->sync();
#else
mgb_throw(MegBrainError,
"Atlas comp_node used but MGB_ATLAS not enabled");
#endif
}
auto version = m_record_nr_req.load(std::memory_order_relaxed);
mgb_assert(version, "device wait on non-recorded event");
......
......@@ -22,6 +22,15 @@
#endif
#endif
#if MGB_CAMBRICON
#include "megcore_cambricon.h"
#endif
#if MGB_ATLAS
#include "acl/acl.h"
#include "megcore_atlas.h"
#endif
using namespace mgb;
/* =================== MegDNNHandle =================== */
......@@ -54,6 +63,28 @@ MegDNNHandle::MegDNNHandle(const CompNodeEnv& env) {
init = true;
}
#endif
#if MGB_CAMBRICON
if (env.property().type == CompNode::DeviceType::CAMBRICON) {
CompNodeEnv::CnrtEnv::init_status.init();
megcore::createDeviceHandleWithGlobalInitStatus(
&m_dev_hdl, env.cnrt_env().device, 0, true);
megcore::createComputingHandleWithCambriconContext(
&m_comp_hdl, m_dev_hdl, 0, {env.cnrt_env().queue});
init = true;
}
#endif
#if MGB_ATLAS
if (env.property().type == CompNode::DeviceType::ATLAS) {
CompNodeEnv::AtlasEnv::init_status.init();
megcore::createAtlasDeviceHandleWithGlobalInitStatus(
&m_dev_hdl, env.atlas_env().device, 0, true);
megcore::createComputingHandleWithAtlasContext(
&m_comp_hdl, m_dev_hdl, 0, {env.atlas_env().stream});
init = true;
}
#endif
if (env.property().type == CompNode::DeviceType::CPU) {
megcoreCreateDeviceHandle(&m_dev_hdl, megcorePlatformCPU);
......@@ -175,6 +206,73 @@ void CompNodeEnv::init_cuda_async(int dev, CompNode comp_node,
}
#endif
#if MGB_ATLAS
void mgb::_on_atlas_error(const char* expr, int err, const char* file,
const char* func, int line) {
mgb_throw(AtlasError, "atlas error %d: %s (%s at %s:%s:%d)", int(err),
megcore::atlas::get_error_str(err), expr, file, func, line);
}
CompNodeEnv::AtlasEnv::InitStatus CompNodeEnv::AtlasEnv::init_status;
void CompNodeEnv::init_atlas(CompNode comp_node, const AtlasEnv& env) {
m_comp_node = comp_node;
m_atlas_env = env;
m_property.type = DeviceType::ATLAS;
m_property.mem_alignment = 64;
m_atlas_env.activate();
MGB_ATLAS_CHECK(aclrtCreateStream(&m_atlas_env.stream));
m_user_data_container = std::make_unique<UserDataContainer>();
mgb_assert(m_property.mem_alignment ==
MegDNNHandle::get(*this).handle()->alignment_requirement());
}
#endif
#if MGB_CAMBRICON
const char* mgb::cnml_get_error_string(cnmlStatus_t err) {
switch (err) {
#define cb(_err) \
case _err: \
return #_err
cb(CNML_STATUS_SUCCESS);
cb(CNML_STATUS_NODEVICE);
cb(CNML_STATUS_DOMAINERR);
cb(CNML_STATUS_INVALIDARG);
cb(CNML_STATUS_LENGTHERR);
cb(CNML_STATUS_OUTOFRANGE);
cb(CNML_STATUS_RANGEERR);
cb(CNML_STATUS_OVERFLOWERR);
cb(CNML_STATUS_UNDERFLOWERR);
cb(CNML_STATUS_INVALIDPARAM);
cb(CNML_STATUS_BADALLOC);
cb(CNML_STATUS_BADTYPEID);
cb(CNML_STATUS_BADCAST);
cb(CNML_STATUS_UNSUPPORT);
#undef cb
}
return "Unknown CNML error";
}
void mgb::_on_cnrt_error(const char* expr, cnrtRet_t err, const char* file,
const char* func, int line) {
mgb_throw(CnrtError, "cnrt error %d: %s (%s at %s:%s:%d)", int(err),
cnrtGetErrorStr(err), expr, file, func, line);
}
void mgb::_on_cndev_error(const char* expr, cndevRet_t err, const char* file,
const char* func, int line) {
mgb_throw(CndevError, "cndev error %d: %s (%s at %s:%s:%d)", int(err),
cndevGetErrorString(err), expr, file, func, line);
}
void mgb::_on_cnml_error(const char* expr, cnmlStatus_t err, const char* file,
const char* func, int line) {
mgb_throw(CnmlError, "cnml error %d: %s (%s at %s:%s:%d)", int(err),
cnml_get_error_string(err), expr, file, func, line);
}
#endif
void CompNodeEnv::init_cpu(const CpuEnv& env, CompNode comp_node) {
m_comp_node = comp_node;
......@@ -188,6 +286,41 @@ void CompNodeEnv::init_cpu(const CpuEnv& env, CompNode comp_node) {
}
#if MGB_CAMBRICON
void CompNodeEnv::init_cnrt(int dev, CompNode comp_node,
const ContinuationCtx<cnrtQueue_t>& cont) {
m_comp_node = comp_node;
m_cnrt_env.device = dev;
m_property.type = DeviceType::CAMBRICON;
MGB_CNRT_CHECK(cnrtGetDeviceInfo(&m_cnrt_env.device_info, dev));
// FIXME: doc doesn't describe the aligment requirement for device memory
// address
m_property.mem_alignment = 1u;
// ensure exception safe
bool queue_created = false;
MGB_MARK_USED_VAR(queue_created);
MGB_TRY {
m_cnrt_env.activate();
MGB_CNRT_CHECK(cnrtCreateQueue(&m_cnrt_env.queue));
queue_created = true;
m_user_data_container = std::make_unique<UserDataContainer>();
cont.next(m_cnrt_env.queue);
// TODO: initialize megdnn handle
mgb_assert(m_property.mem_alignment ==
MegDNNHandle::get(*this).handle()->alignment_requirement());
}
MGB_CATCH(std::exception & exc, {
mgb_log_error("cnrt init failed: %s", exc.what());
if (queue_created) {
MGB_CNRT_CHECK(cnrtDestroyQueue(m_cnrt_env.queue));
}
cont.err(exc);
throw;
})
}
CompNodeEnv::CnrtEnv::InitStatus CompNodeEnv::CnrtEnv::init_status;
#endif
void CompNodeEnv::fini() {
ensure_async_init_finished();
m_user_data_container.reset();
......@@ -197,6 +330,19 @@ void CompNodeEnv::fini() {
MGB_CUDA_CHECK(cudaStreamDestroy(m_cuda_env.stream));
}
#endif
#if MGB_CAMBRICON
if (m_property.type == DeviceType::CAMBRICON) {
m_cnrt_env.activate();
MGB_CNRT_CHECK(cnrtDestroyQueue(m_cnrt_env.queue));
}
#endif
#if MGB_ATLAS
if (m_property.type == DeviceType::ATLAS) {
m_atlas_env.activate();
MGB_ATLAS_CHECK(aclrtDestroyStream(m_atlas_env.stream));
}
#endif
}
#if MGB_ENABLE_COMP_NODE_ASYNC_INIT
......
......@@ -71,6 +71,29 @@ std::string CudaError::get_cuda_extra_info() {
#endif
}
AtlasError::AtlasError(const std::string &msg):
SystemError(msg)
{
}
CnrtError::CnrtError(const std::string& msg) : SystemError(msg) {
m_msg.append(get_cnrt_extra_info());
}
std::string CnrtError::get_cnrt_extra_info() {
#if MGB_CAMBRICON
// get last error
auto err = cnrtGetLastErr();
return ssprintf("(last_err=%d(%s))", err, cnrtGetErrorStr(err));
#else
return "cnrt disabled at compile time";
#endif
}
CndevError::CndevError(const std::string& msg) : SystemError(msg) {}
CnmlError::CnmlError(const std::string& msg) : SystemError(msg) {}
bool mgb::has_uncaught_exception() {
#if MGB_ENABLE_EXCEPTION
......
......@@ -124,7 +124,7 @@ StaticDeviceMemoryManager::make_default_impl() {
#endif // MGB_THREAD_SAFE
/* ==================== AsyncVarReleaser ==================== */
#if MGB_CUDA
#if MGB_CUDA || MGB_ATLAS
class VarNodeMemManager::AsyncVarReleaser {
struct WaiterParam {
CompNode cn;
......@@ -247,7 +247,7 @@ bool VarNodeMemManager::ImpureMemPlanManager::check_need_realloc() {
VarNodeMemManager::VarNodeMemManager(ComputingGraphImpl *graph):
m_owner_graph(graph),
m_seq_mem_opt(graph)
#if MGB_CUDA
#if MGB_CUDA || MGB_ATLAS
,m_asyn_var_releaser(new AsyncVarReleaser)
#endif
{
......@@ -255,7 +255,7 @@ VarNodeMemManager::VarNodeMemManager(ComputingGraphImpl *graph):
MGB_MARK_USED_VAR(ev);
// async release is only used for sync between multiple comp nodes, and
// does not wait for device to finish
#if MGB_CUDA
#if MGB_CUDA || MGB_ATLAS
m_asyn_var_releaser->wait_release_finish();
#endif
m_cpu_async_release_barrier.wait_zero();
......@@ -296,8 +296,7 @@ VarNodeMemManager::VarNodeMemManager(ComputingGraphImpl *graph):
graph->event().register_receiver_permanent<event::CompSeqExecError>(
on_comp_seq_error);
#if MGB_ENABLE_VAR_DEV_MEM_DEFRAGMENTER && (MGB_CUDA \
)
#if MGB_ENABLE_VAR_DEV_MEM_DEFRAGMENTER && (MGB_CUDA || MGB_ATLAS)
auto on_mem_defrag_start = [this](const event::BeforeMemDefrag&) {
m_asyn_var_releaser->wait_release_finish();
};
......@@ -1350,6 +1349,13 @@ void VarNodeMemManager::decr_var_mem_refcnt(
case DT::CUDA:
m_asyn_var_releaser->add(dispatch_cn, var);
break;
#endif
#if MGB_ATLAS
case DT::ATLAS:
{
m_asyn_var_releaser->add(dispatch_cn, var);
break;
}
#endif
default:
mgb_throw(MegBrainError,
......
......@@ -437,7 +437,7 @@ class VarNodeMemManager {
SyncableCounter m_cpu_async_release_barrier;
#if MGB_CUDA
#if MGB_CUDA || MGB_ATLAS
//! release dynamic var on after compnode event finishes
class AsyncVarReleaser;
std::unique_ptr<AsyncVarReleaser> m_asyn_var_releaser;
......
......@@ -611,6 +611,12 @@ void mgb::dev_tensor_memset(const DeviceTensorND& tensor, int val) {
MGB_CUDA_CHECK(
cudaMemsetAsync(ptr, val, size, env.cuda_env().stream));
break;
#endif
#if MGB_ATLAS
case CompNode::DeviceType::ATLAS:
MGB_ATLAS_CHECK(aclrtMemsetAsync(ptr, -1, val, size,
env.atlas_env().stream));
break;
#endif
case CompNode::DeviceType::CPU: {
auto fill = [ptr, size, val]() { std::memset(ptr, val, size); };
......
......@@ -112,6 +112,8 @@ class CompNode {
CUDA = 1,
CPU = 2,
CAMBRICON = 3,
ATLAS = 9,
MULTITHREAD,
MAX_DEVICE_ID,
};
......
......@@ -139,6 +139,32 @@ public:
CudaError(const std::string& msg);
};
class AtlasError final: public SystemError {
public:
AtlasError(const std::string& msg);
};
class CnrtError final : public SystemError {
public:
/*!
* \brief get extra info for current cnrt status, to be appended in
* error message
*/
static std::string get_cnrt_extra_info();
CnrtError(const std::string& msg);
};
class CndevError final : public SystemError {
public:
CndevError(const std::string& msg);
};
class CnmlError final : public SystemError {
public:
CnmlError(const std::string& msg);
};
class AssertionError final : public MegBrainError {
public:
......
此差异已折叠。
......@@ -13,6 +13,8 @@
#define _HEADER_MGB_BUILD_CONFIG
#cmakedefine01 MGB_CUDA
#cmakedefine01 MGB_CAMBRICON
#cmakedefine01 MGB_ATLAS
#cmakedefine01 MGB_ASSERT_LOC
#cmakedefine01 MGB_ENABLE_DEBUG_UTIL
#cmakedefine01 MGB_ENABLE_LOGGING
......@@ -54,6 +56,10 @@
#cmakedefine01 MEGDNN_THREADS_512
#cmakedefine01 MEGDNN_ENABLE_MULTI_THREADS
// whether atlas is available
#ifndef MGB_ATLAS
#define MGB_ATLAS 0
#endif
// whether cuda is available
#ifndef MGB_CUDA
......@@ -135,6 +141,15 @@
#endif
#ifndef MEGDNN_WITH_CAMBRICON
#define MEGDNN_WITH_CAMBRICON 0
#endif
#ifndef MGB_CAMBRICON
#define MGB_CAMBRICON MEGDNN_WITH_CAMBRICON
#endif
// whether to enable TensorRT support
#ifndef MGB_ENABLE_TENSOR_RT
#define MGB_ENABLE_TENSOR_RT MGB_CUDA
......
此差异已折叠。
# register the `atlas_runtime` raw operator: wraps a serialized ACL
# offline model so it can be executed as a graph operator
decl_raw_opr(
'atlas_runtime',
desc='create an operator that could load and run acl offline model',
inputs=[
Doc('inputs', 'input vars', 'list of :class:`.SymbolVar`'),
Doc('data_bytes', 'serialized acl model'),
],
body=[
'assert isinstance(data_bytes, bytes), '
'"data must be bytes; got {}".format(type(data_bytes))',
'output = _mgb._Opr.atlas_runtime(inputs, data_bytes, config)',
'cvt_result_kwargs["explode_single"] = False',
],
)
# vim: ft=python
/**
* \file src/opr/impl/atlas_runtime_opr.sereg.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "megbrain/opr/atlas_runtime_op.h"
#include "megbrain/serialization/sereg.h"
#if MGB_ATLAS
namespace mgb {
namespace serialization {
//! (de)serialization for AtlasRuntimeOpr: the opr is dumped as its raw
//! ACL model buffer, length-prefixed
template <>
struct OprLoadDumpImpl<opr::AtlasRuntimeOpr, 0> {
static void dump(OprDumpContext& ctx, const cg::OperatorNodeBase& opr_) {
auto&& opr = opr_.cast_final_safe<opr::AtlasRuntimeOpr>();
auto&& buf = opr.buffer();
ctx.dump_buf_with_len(buf.data(), buf.size());
}
static cg::OperatorNodeBase* load(OprLoadContext& ctx,
const cg::VarNodeArray& inputs,
const OperatorNodeConfig& config) {
// activate the device of the first input before building the opr
inputs.at(0)->comp_node().activate();
auto buf = ctx.load_shared_buf_with_len();
return opr::AtlasRuntimeOpr::make(
std::move(buf), cg::to_symbol_var_array(inputs),
config)
.at(0)
.node()
->owner_opr();
}
};
} // namespace serialization
namespace opr {
//! shallow copy: rebuild the opr on new inputs, reusing the original's
//! buffer and already-loaded model
cg::OperatorNodeBase* opr_shallow_copy_atlas_runtime_opr(
const serialization::OprShallowCopyContext& ctx,
const cg::OperatorNodeBase& opr_, const VarNodeArray& inputs,
const OperatorNodeConfig& config) {
MGB_MARK_USED_VAR(ctx);
auto&& opr = opr_.cast_final_safe<AtlasRuntimeOpr>();
return AtlasRuntimeOpr::make(opr.buffer(), opr.model(),
cg::to_symbol_var_array(inputs),
config)
.at(0)
.node()
->owner_opr();
}
MGB_SEREG_OPR(AtlasRuntimeOpr, 0);
MGB_REG_OPR_SHALLOW_COPY(AtlasRuntimeOpr, opr_shallow_copy_atlas_runtime_opr);
} // namespace opr
} // namespace mgb
#endif
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
此差异已折叠。
此差异已折叠。
此差异已折叠。
......@@ -118,6 +118,8 @@ void run_test(const PluginMaker& plugin_maker,
const ResultChecker& result_checker) {
for (size_t i = 1; i < CompNode::NR_DEVICE_TYPE; ++i) {
auto type = static_cast<CompNode::DeviceType>(i);
if (!check_device_type_avaiable(type))
continue;
if (CompNode::get_device_count(type)) {
auto cn = CompNode::load({type, -1, 0});
if (cn.contain_flag(CompNode::Flag::SUPPORT_RECORDER)) {
......@@ -188,4 +190,3 @@ TEST(TestOprIODump, Binary) {
}
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
......@@ -32,6 +32,12 @@ namespace mgb{void call_sereg(){}}
#if MGB_ENABLE_TENSOR_RT
#include "../../tensorrt/impl/tensorrt_opr.sereg.h"
#endif
#if MGB_ATLAS
#include "../../opr/impl/atlas_runtime_op.sereg.h"
#endif
#if MGB_JIT
#include "../../jit/impl/jit.sereg.h"
#endif
#if MGB_CAMBRICON
#include "../../cambricon/impl/cambricon_runtime_opr.sereg.h"
#endif
此差异已折叠。
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册