提交 7f857bd4 编写于 作者: M Megvii Engine Team

feat(mgb/rocm): add cmake for rocm and fix compile errors and bn

GitOrigin-RevId: c73ed4adc37ecaf71832dd6dfb990405fb1873ab
上级 199eefbd
...@@ -52,6 +52,7 @@ option(MGE_BUILD_SDK "Build load_and_run" ON) ...@@ -52,6 +52,7 @@ option(MGE_BUILD_SDK "Build load_and_run" ON)
option(MGE_INFERENCE_ONLY "Build inference only library." OFF) option(MGE_INFERENCE_ONLY "Build inference only library." OFF)
option(MGE_WITH_PYTHON_MODULE "Build MegEngine Python Module." ON) option(MGE_WITH_PYTHON_MODULE "Build MegEngine Python Module." ON)
option(MGE_WITH_MKLDNN "Enable Intel MKL_DNN support," ON) option(MGE_WITH_MKLDNN "Enable Intel MKL_DNN support," ON)
option(MGE_WITH_ROCM "Enable ROCM support" OFF)
if (APPLE) if (APPLE)
set (BUILD_SHARED_LIBS OFF) set (BUILD_SHARED_LIBS OFF)
...@@ -442,6 +443,10 @@ if(MGE_WITH_CAMBRICON) ...@@ -442,6 +443,10 @@ if(MGE_WITH_CAMBRICON)
set(MGE_CAMBRICON_LIBS "${MGE_CAMBRICON_LIBS}") set(MGE_CAMBRICON_LIBS "${MGE_CAMBRICON_LIBS}")
endif() endif()
if (MGE_WITH_ROCM)
include(cmake/rocm.cmake)
endif ()
if(MGE_WITH_ATLAS) if(MGE_WITH_ATLAS)
include(cmake/aclrt.cmake) include(cmake/aclrt.cmake)
......
if(NOT DEFINED HIP_PATH)
if(NOT DEFINED ENV{HIP_PATH})
set(HIP_PATH "/opt/rocm/hip" CACHE PATH "Path to which HIP has been installed")
else()
set(HIP_PATH $ENV{HIP_PATH} CACHE PATH "Path to which HIP has been installed")
endif()
endif()
set(CMAKE_MODULE_PATH "${HIP_PATH}/cmake" ${CMAKE_MODULE_PATH})
find_package(HIP QUIET)
if (HIP_FOUND)
message(STATUS "Found HIP: " ${HIP_VERSION})
else()
message(FATAL_ERROR "Could not find HIP. Ensure that HIP is either installed in /opt/rocm/hip or the variable HIP_PATH is set to point to the right location.")
endif()
string(REPLACE "." ";" HIP_VERSION_LIST ${HIP_VERSION})
list(GET HIP_VERSION_LIST 0 HIP_VERSION_MAJOR)
list(GET HIP_VERSION_LIST 1 HIP_VERSION_MINOR)
if (NOT ${HIP_VERSION_MAJOR} STREQUAL "3")
message(FATAL_ERROR "ROCM version needed 3.7.Please update ROCM.")
endif()
if (NOT ${HIP_VERSION_MINOR} STREQUAL "7")
message(FATAL_ERROR "ROCM version needed 3.7.Please update ROCM.")
endif()
set(MGE_ROCM_LIBS OpenCL amdhip64 MIOpen rocblas rocrand)
set(HIP_INCLUDE_DIR ${HIP_ROOT_DIR}/../include)
set(HIP_LIBRARY_DIR ${HIP_ROOT_DIR}/../lib)
#miopen
get_filename_component(__found_miopen_library ${HIP_ROOT_DIR}/../miopen/lib REALPATH)
find_path(MIOPEN_LIBRARY_DIR
NAMES libMIOpen.so
HINTS ${PC_MIOPEN_INCLUDE_DIRS} ${MIOPEN_ROOT_DIR} ${ROCM_TOOLKIT_INCLUDE} ${__found_miopen_library}
PATH_SUFFIXES lib
DOC "Path to MIOPEN library directory." )
if(MIOPEN_LIBRARY_DIR STREQUAL "MIOPEN_LIBRARY_DIR-NOTFOUND")
message(FATAL_ERROR "Can not find MIOPEN Library")
endif()
get_filename_component(__found_miopen_include ${HIP_ROOT_DIR}/../miopen/include REALPATH)
find_path(MIOPEN_INCLUDE_DIR
NAMES miopen
HINTS ${PC_MIOPEN_INCLUDE_DIRS} ${MIOPEN_ROOT_DIR} ${ROCM_TOOLKIT_INCLUDE} ${__found_miopen_include}
PATH_SUFFIXES include
DOC "Path to MIOPEN include directory." )
if(MIOPEN_INCLUDE_DIR STREQUAL "MIOPEN_INCLUDE_DIR-NOTFOUND")
message(FATAL_ERROR "Can not find MIOEPN INCLUDE")
endif()
#rocblas
get_filename_component(__found_rocblas_library ${HIP_ROOT_DIR}/../rocblas/lib REALPATH)
find_path(ROCBLAS_LIBRARY_DIR
NAMES librocblas.so
HINTS ${PC_ROCBLAS_INCLUDE_DIRS} ${ROCBLAS_ROOT_DIR} ${ROCM_TOOLKIT_INCLUDE} ${__found_rocblas_library}
PATH_SUFFIXES lib
DOC "Path to ROCBLAS library directory." )
if(ROCBLAS_LIBRARY_DIR STREQUAL "ROCBLAS_LIBRARY_DIR-NOTFOUND")
message(FATAL_ERROR "Can not find ROCBLAS Library")
endif()
get_filename_component(__found_rocblas_include ${HIP_ROOT_DIR}/../rocblas/include REALPATH)
find_path(ROCBLAS_INCLUDE_DIR
NAMES rocblas.h
HINTS ${PC_ROCBLAS_INCLUDE_DIRS} ${ROCBLAS_ROOT_DIR} ${ROCM_TOOLKIT_INCLUDE} ${__found_rocblas_include}
PATH_SUFFIXES include
DOC "Path to ROCBLAS include directory." )
if(ROCBLAS_INCLUDE_DIR STREQUAL "ROCBLAS_INCLUDE_DIR-NOTFOUND")
message(FATAL_ERROR "Can not find ROCBLAS INCLUDE")
endif()
#rocrand
get_filename_component(__found_rocrand_library ${HIP_ROOT_DIR}/../rocrand/lib REALPATH)
find_path(ROCRAND_LIBRARY_DIR
NAMES librocrand.so
HINTS ${PC_ROCRAND_INCLUDE_DIRS} ${ROCRAND_ROOT_DIR} ${ROCM_TOOLKIT_INCLUDE} ${__found_rocrand_library}
PATH_SUFFIXES lib
DOC "Path to ROCRAND library directory." )
if(ROCRAND_LIBRARY_DIR STREQUAL "ROCRAND_LIBRARY_DIR-NOTFOUND")
message(FATAL_ERROR "Can not find ROCRAND Library")
endif()
get_filename_component(__found_rocrand_include ${HIP_ROOT_DIR}/../rocrand/include REALPATH)
find_path(ROCRAND_INCLUDE_DIR
NAMES rocrand.h
HINTS ${PC_ROCRAND_INCLUDE_DIRS} ${ROCRAND_ROOT_DIR} ${ROCM_TOOLKIT_INCLUDE} ${__found_rocrand_include}
PATH_SUFFIXES include
DOC "Path to ROCRAND include directory." )
if(ROCRAND_INCLUDE_DIR STREQUAL "ROCRAND_INCLUDE_DIR-NOTFOUND")
message(FATAL_ERROR "Can not find ROCRAND INCLUDE")
endif()
...@@ -16,6 +16,8 @@ ...@@ -16,6 +16,8 @@
#include "hip_header.h" #include "hip_header.h"
#include "megdnn/internal/visibility_prologue.h" #include "megdnn/internal/visibility_prologue.h"
#include <atomic>
namespace megcore { namespace megcore {
struct ROCMContext { struct ROCMContext {
hipStream_t stream = nullptr; hipStream_t stream = nullptr;
......
...@@ -11,7 +11,7 @@ ...@@ -11,7 +11,7 @@
#include "megbrain_build_config.h" #include "megbrain_build_config.h"
#if !defined(__CUDACC__) #if !defined(__CUDACC__) && !defined(__HIPCC__)
#endif // !defined(__CUDACC__) #endif // !defined(__CUDACC__)
......
...@@ -103,10 +103,12 @@ namespace megdnn { ...@@ -103,10 +103,12 @@ namespace megdnn {
* \brief iterate through each dtype object that can be involved in float * \brief iterate through each dtype object that can be involved in float
* numeric computing * numeric computing
*/ */
#define MEGDNN_FOREACH_COMPUTING_DTYPE_FLOAT(cb) \ #define MEGDNN_FOREACH_COMPUTING_DTYPE_FLOAT(cb) \
cb(::megdnn::dtype::Float32) \ cb(::megdnn::dtype::Float32) \
MEGDNN_INC_FLOAT16(cb(::megdnn::dtype::Float16)) \ MEGDNN_INC_FLOAT16(cb(::megdnn::dtype::Float16)) \
MEGDNN_INC_FLOAT16(cb(::megdnn::dtype::BFloat16)) \ MEGDNN_INC_FLOAT16(cb(::megdnn::dtype::BFloat16))
/*! /*!
* \brief iterate through each dtype object that can be involved in integer * \brief iterate through each dtype object that can be involved in integer
......
...@@ -2809,6 +2809,7 @@ namespace std { ...@@ -2809,6 +2809,7 @@ namespace std {
/// Numeric limits for bfloat16-precision floats. /// Numeric limits for bfloat16-precision floats.
/// Because of the underlying single-precision implementation of many /// Because of the underlying single-precision implementation of many
/// operations, it inherits some properties from `numeric_limits<float>`. /// operations, it inherits some properties from `numeric_limits<float>`.
#if !defined(__HIPCC__)
template <> template <>
class numeric_limits<half_bfloat16::bfloat16> : public numeric_limits<float> { class numeric_limits<half_bfloat16::bfloat16> : public numeric_limits<float> {
public: public:
...@@ -2932,6 +2933,7 @@ public: ...@@ -2932,6 +2933,7 @@ public:
0x0001); 0x0001);
} }
}; };
#endif
#ifdef MEGDNN_CC_HOST #ifdef MEGDNN_CC_HOST
#if HALF_ENABLE_CPP11_HASH #if HALF_ENABLE_CPP11_HASH
......
...@@ -37,6 +37,66 @@ if(NOT ${MGE_ARCH} STREQUAL "naive") ...@@ -37,6 +37,66 @@ if(NOT ${MGE_ARCH} STREQUAL "naive")
endif() endif()
###############################################################################
# HIP_COMPILE
###############################################################################
macro (HIP_COMPILE _hip_target _hip_objs)
# Separate the sources from the options
HIP_GET_SOURCES_AND_OPTIONS(_sources
_cmake_options
_hipcc_options
_hcc_options
_nvcc_options
${ARGN})
HIP_PREPARE_TARGET_COMMANDS(${_hip_target}
OBJ _generated_files _source_files ${_sources} ${_cmake_options}
HIPCC_OPTIONS ${_hipcc_options}
HCC_OPTIONS ${_hcc_options}
NVCC_OPTIONS ${_nvcc_options})
if(_source_files)
list(REMOVE_ITEM _sources ${_source_files})
endif()
add_custom_target(${_hip_target})
# set return value
set (${_hip_objs} ${_generated_files})
endmacro()
if (MGE_WITH_ROCM)
file (GLOB_RECURSE SOURCES_ rocm/*.cpp)
list (APPEND SOURCES ${SOURCES_})
# FIXME rocm may lost the first hip file, so currently we just create an
# empty file to bypass this error.
file(GLOB start.cpp.hip "" )
list(APPEND HIP_SOURCES start.cpp.hip)
file (GLOB_RECURSE HIPSOURCES rocm/*.cpp.hip)
set(HIP_TARGET_NAME hip_kernel)
set(_HIPCC_OPTIONS "-fPIC")
set(_HCC_OPTIONS "-fPIC")
set(_NVCC_OPTIONS "-fPIC")
list(APPEND HIP_SOURCES ${HIPSOURCES})
set_source_files_properties(${HIP_SOURCES} PROPERTIES HIP_SOURCE_PROPERTY_FORMAT 1)
HIP_INCLUDE_DIRECTORIES(${PROJECT_SOURCE_DIR}/dnn
${PROJECT_SOURCE_DIR}/dnn/include
${PROJECT_BINARY_DIR}/dnn
${PROJECT_BINARY_DIR}/genfiles
${PROJECT_BINARY_DIR}/dnn/include
${HIP_INCLUDE_DIR}
${MIOPEN_INCLUDE_DIR}
${ROCBLAS_INCLUDE_DIR}
${ROCRAND_INCLUDE_DIR})
hip_compile(
${HIP_TARGET_NAME} HIPOBJS ${HIP_SOURCES}
HIPCC_OPTIONS ${_HIPCC_OPTIONS}
HCC_OPTIONS ${_HCC_OPTIONS}
NVCC_OPTIONS ${_NVCC_OPTIONS})
list (APPEND SOURCES ${HIPOBJS})
endif ()
if(MGE_WITH_CUDA) if(MGE_WITH_CUDA)
file(GLOB_RECURSE SOURCES_ cuda/*.cpp) file(GLOB_RECURSE SOURCES_ cuda/*.cpp)
list(APPEND SOURCES ${SOURCES_}) list(APPEND SOURCES ${SOURCES_})
...@@ -73,6 +133,19 @@ if(MGE_WITH_CUDA) ...@@ -73,6 +133,19 @@ if(MGE_WITH_CUDA)
target_link_libraries(megdnn PUBLIC cutlass) target_link_libraries(megdnn PUBLIC cutlass)
endif() endif()
if(MGE_WITH_ROCM)
target_include_directories(megdnn PUBLIC
${HIP_INCLUDE_DIR}
${MIOPEN_INCLUDE_DIR}
${ROCBLAS_INCLUDE_DIR}
${ROCRAND_INCLUDE_DIR})
target_link_directories(megdnn PUBLIC
${HIP_LIBRARY_DIR}
${MIOPEN_LIBRARY_DIR}
${ROCBLAS_LIBRARY_DIR}
${ROCRAND_LIBRARY_DIR})
endif()
if(${MGE_ARCH} STREQUAL "x86_64" OR ${MGE_ARCH} STREQUAL "i386" OR ${MGE_ARCH} STREQUAL "armv7" OR ${MGE_ARCH} STREQUAL "aarch64") if(${MGE_ARCH} STREQUAL "x86_64" OR ${MGE_ARCH} STREQUAL "i386" OR ${MGE_ARCH} STREQUAL "armv7" OR ${MGE_ARCH} STREQUAL "aarch64")
if(MGE_ENABLE_CPUINFO) if(MGE_ENABLE_CPUINFO)
target_link_libraries(megdnn PRIVATE $<BUILD_INTERFACE:cpuinfo>) target_link_libraries(megdnn PRIVATE $<BUILD_INTERFACE:cpuinfo>)
...@@ -115,6 +188,10 @@ else() ...@@ -115,6 +188,10 @@ else()
target_link_libraries(megdnn PRIVATE ${MGE_BLAS_LIBS}) target_link_libraries(megdnn PRIVATE ${MGE_BLAS_LIBS})
endif() endif()
if (MGE_WITH_ROCM)
target_link_libraries(megdnn PRIVATE ${HIPOBJS} ${MGE_ROCM_LIBS})
endif ()
if(MGE_WITH_ATLAS) if(MGE_WITH_ATLAS)
if (BUILD_SHARED_LIBS) if (BUILD_SHARED_LIBS)
target_link_libraries(megdnn PRIVATE $<BUILD_INTERFACE:${MGE_ATLAS_LIBS}>) target_link_libraries(megdnn PRIVATE $<BUILD_INTERFACE:${MGE_ATLAS_LIBS}>)
......
...@@ -40,7 +40,7 @@ ...@@ -40,7 +40,7 @@
* -------------------------------------------------------------------------- * --------------------------------------------------------------------------
*/ */
#ifndef __CUDACC__ #if !__CUDACC__ && !__HIPCC__
#include <cmath> #include <cmath>
......
...@@ -27,6 +27,22 @@ public: ...@@ -27,6 +27,22 @@ public:
return 0; return 0;
} }
std::vector<Algorithm*> get_all_algorithms(
const TensorLayout& /*A*/, const TensorLayout& /*B*/,
const TensorLayout& /*C*/) override {
return {};
}
Algorithm* get_algorithm_heuristic(const TensorLayout& /*A*/,
const TensorLayout& /*B*/,
const TensorLayout& /*C*/,
size_t /*workspace_limit_in_bytes*/,
bool /* reproducible */) override {
return nullptr;
}
const char* get_algorithm_set_name() const override { return "DEFAULT"; }
bool is_thread_safe() const override { return true; } bool is_thread_safe() const override { return true; }
private: private:
......
...@@ -124,6 +124,9 @@ INST_FOR_CTYPE ...@@ -124,6 +124,9 @@ INST_FOR_CTYPE
INST_FOR_CTYPE INST_FOR_CTYPE
#undef ct #undef ct
#endif #endif
#define ct dt_bfloat16
INST_FOR_CTYPE
#undef ct
#define ct dt_int8 #define ct dt_int8
INST_FOR_CTYPE INST_FOR_CTYPE
#undef ct #undef ct
...@@ -142,6 +145,9 @@ INST_FOR_CTYPE ...@@ -142,6 +145,9 @@ INST_FOR_CTYPE
#define ct dt_qint32 #define ct dt_qint32
INST_FOR_CTYPE INST_FOR_CTYPE
#undef ct #undef ct
#define ct dt_bool
INST_FOR_CTYPE
#undef ct
#undef ndim_cb #undef ndim_cb
......
...@@ -36,6 +36,9 @@ ...@@ -36,6 +36,9 @@
#include "src/rocm/argmxx/opr_impl.h" #include "src/rocm/argmxx/opr_impl.h"
#include "src/rocm/sleep/opr_impl.h" #include "src/rocm/sleep/opr_impl.h"
#include <miopen/version.h>
#include <hip/hip_version.h>
#include <cstring> #include <cstring>
#define STR_HELPER(x) #x #define STR_HELPER(x) #x
...@@ -56,7 +59,7 @@ std::unique_ptr<Handle> Handle::make_rocm_handle(megcoreComputingHandle_t comput ...@@ -56,7 +59,7 @@ std::unique_ptr<Handle> Handle::make_rocm_handle(megcoreComputingHandle_t comput
} }
template <typename Opr> template <typename Opr>
std::unique_ptr<Opr> Handle::create_rocm_operator() { std::unique_ptr<Opr> Handle::create_rocm_operator() {
return static_cast<rocm::HandleImpl*>(this)->create_operator<Opr>(); return static_cast<rocm::HandleImpl*>(this)->create_operator<Opr>();
} }
#define INST(opr) \ #define INST(opr) \
template std::unique_ptr<opr> Handle::create_rocm_operator(); template std::unique_ptr<opr> Handle::create_rocm_operator();
...@@ -178,7 +181,8 @@ MEGDNN_FOREACH_OPR_CLASS(MEGDNN_INST_CREATE_OPERATOR) ...@@ -178,7 +181,8 @@ MEGDNN_FOREACH_OPR_CLASS(MEGDNN_INST_CREATE_OPERATOR)
} // namespace rocm } // namespace rocm
} // namespace megdnn } // namespace megdnn
MEGDNN_VERSION_SYMBOL(HIP, HIP_VERSION);
MEGDNN_VERSION_SYMBOL3(HIP, HIP_VERSION_MAJOR, HIP_VERSION_MINOR, HIP_VERSION_PATCH);
MEGDNN_VERSION_SYMBOL3(MIOPEN, MIOPEN_VERSION_MAJOR, MIOPEN_VERSION_MINOR, MEGDNN_VERSION_SYMBOL3(MIOPEN, MIOPEN_VERSION_MAJOR, MIOPEN_VERSION_MINOR,
MIOPEN_VERSION_PATCH); MIOPEN_VERSION_PATCH);
// vim: syntax=cpp.doxygen // vim: syntax=cpp.doxygen
...@@ -11,6 +11,11 @@ ...@@ -11,6 +11,11 @@
#include "hip_header.h" #include "hip_header.h"
#include "megdnn/dtype.h" #include "megdnn/dtype.h"
__device__ void atomicAdd(megdnn::dt_bfloat16 *, megdnn::dt_bfloat16) {
asm("s_trap 2;");
((int*)0)[0] = 1;
}
#if !MEGDNN_DISABLE_FLOAT16 #if !MEGDNN_DISABLE_FLOAT16
__device__ void atomicAdd(megdnn::dt_float16 *, megdnn::dt_float16) { __device__ void atomicAdd(megdnn::dt_float16 *, megdnn::dt_float16) {
asm("s_trap 2;"); asm("s_trap 2;");
......
...@@ -36,7 +36,7 @@ MEGDNN_FOREACH_COMPUTING_DTYPE(cb) ...@@ -36,7 +36,7 @@ MEGDNN_FOREACH_COMPUTING_DTYPE(cb)
INST(dt_float16, dt_float16, float) INST(dt_float16, dt_float16, float)
INST(dt_float16, float, float) INST(dt_float16, float, float)
INST(float, dt_float16, float) INST(float, dt_float16, float)
INST(int, float, float)
#undef cb #undef cb
#undef INST #undef INST
......
...@@ -142,6 +142,7 @@ void typecvt_kern_n2n(const TensorND& dest, const TensorND& src, ...@@ -142,6 +142,7 @@ void typecvt_kern_n2n(const TensorND& dest, const TensorND& src,
cb(dtype_src, dt_uint8) \ cb(dtype_src, dt_uint8) \
cb(dtype_src, dt_float32) \ cb(dtype_src, dt_float32) \
cb(dtype_src, dt_float16) \ cb(dtype_src, dt_float16) \
cb(dtype_src, dt_bfloat16) \
#else #else
...@@ -176,6 +177,7 @@ void typecvt_kern_n2n(const TensorND& dest, const TensorND& src, ...@@ -176,6 +177,7 @@ void typecvt_kern_n2n(const TensorND& dest, const TensorND& src,
cb(dt_uint8) \ cb(dt_uint8) \
cb(dt_float32) \ cb(dt_float32) \
cb(dt_float16) \ cb(dt_float16) \
cb(dt_bfloat16) \
#else #else
#define MEGDNN_FOREACH_COMPUTING_CTYPE(cb) \ #define MEGDNN_FOREACH_COMPUTING_CTYPE(cb) \
......
...@@ -259,9 +259,7 @@ void transpose_knc2nsck(const float *src, float *dst, ...@@ -259,9 +259,7 @@ void transpose_knc2nsck(const float *src, float *dst,
MEGDNN_ATTRIBUTE_TARGET("sse") MEGDNN_ATTRIBUTE_TARGET("sse")
void x86::disable_denorm() { void x86::disable_denorm() {
//printf("before: %x\n", _mm_getcsr());
_mm_setcsr(_mm_getcsr() | (_MM_FLUSH_ZERO_ON | _MM_DENORMALS_ZERO_ON)); _mm_setcsr(_mm_getcsr() | (_MM_FLUSH_ZERO_ON | _MM_DENORMALS_ZERO_ON));
//printf("after: %x\n", _mm_getcsr());
} }
} // namespace megdnn } // namespace megdnn
......
...@@ -36,6 +36,11 @@ if(MGE_WITH_ATLAS) ...@@ -36,6 +36,11 @@ if(MGE_WITH_ATLAS)
list(APPEND SOURCES ${SOURCES_}) list(APPEND SOURCES ${SOURCES_})
endif() endif()
if (MGE_WITH_ROCM)
file (GLOB_RECURSE SOURCES_ rocm/*.cpp)
list (APPEND SOURCES ${SOURCES_})
endif()
add_executable(megdnn_test ${SOURCES}) add_executable(megdnn_test ${SOURCES})
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-narrowing") set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-narrowing")
...@@ -61,6 +66,10 @@ if(MGE_ENABLE_COVERAGE) ...@@ -61,6 +66,10 @@ if(MGE_ENABLE_COVERAGE)
set(CMAKE_EXE_LINKER_FLAGS "${CMAKE_EXE_LINKER_FLAGS} --coverage") set(CMAKE_EXE_LINKER_FLAGS "${CMAKE_EXE_LINKER_FLAGS} --coverage")
endif() endif()
if (MEG_WITH_ROCM)
target_link_libraries (megdnn_test ${MGE_ROCM_LIBS})
endif ()
if(APPLE OR ANDROID) if(APPLE OR ANDROID)
target_link_libraries(megdnn_test dl) target_link_libraries(megdnn_test dl)
else() else()
......
...@@ -202,7 +202,7 @@ TEST_F(ROCM, INDEXING_MULTI_AXIS_VEC_BENCHMARK) { ...@@ -202,7 +202,7 @@ TEST_F(ROCM, INDEXING_MULTI_AXIS_VEC_BENCHMARK) {
set_rng(1, &rng_inp). set_rng(1, &rng_inp).
set_rng(2, &rng0). set_rng(2, &rng0).
set_rng(3, &rng1). set_rng(3, &rng1).
set_proxy({0, 1}); set_proxy({{0, 1}});
auto time_ms = benchmarker.execs({{1000, 1000, 1000}, {1000, 1000}, {1000}, {1000}}); auto time_ms = benchmarker.execs({{1000, 1000, 1000}, {1000, 1000}, {1000}, {1000}});
long io = 2 * 1000 * 1000 * dtype::Float32().size(); long io = 2 * 1000 * 1000 * dtype::Float32().size();
printf("io = %.3f GB, random access bandwidth = %.3f GB/s\n", printf("io = %.3f GB, random access bandwidth = %.3f GB/s\n",
......
...@@ -71,22 +71,22 @@ TEST_F(ROCM, MATRIX_MUL) { ...@@ -71,22 +71,22 @@ TEST_F(ROCM, MATRIX_MUL) {
BS = TensorShape{k, n}; BS = TensorShape{k, n};
CS = TensorShape{m, n}; CS = TensorShape{m, n};
TensorLayout AL, BL, CL; TensorLayout AL, BL, CL;
if (arg.Astride == 0) { if (arg.A_stride == 0) {
AL = TensorLayout(AS, dtype::Float32()); AL = TensorLayout(AS, dtype::Float32());
} else { } else {
AL = TensorLayout(AS, {ptrdiff_t(arg.Astride), 1}, AL = TensorLayout(AS, {ptrdiff_t(arg.A_stride), 1},
dtype::Float32()); dtype::Float32());
} }
if (arg.Bstride == 0) { if (arg.B_stride == 0) {
BL = TensorLayout(BS, dtype::Float32()); BL = TensorLayout(BS, dtype::Float32());
} else { } else {
BL = TensorLayout(BS, {ptrdiff_t(arg.Bstride), 1}, BL = TensorLayout(BS, {ptrdiff_t(arg.B_stride), 1},
dtype::Float32()); dtype::Float32());
} }
if (arg.Cstride == 0) { if (arg.C_stride == 0) {
CL = TensorLayout(CS, dtype::Float32()); CL = TensorLayout(CS, dtype::Float32());
} else { } else {
CL = TensorLayout(CS, {ptrdiff_t(arg.Cstride), 1}, CL = TensorLayout(CS, {ptrdiff_t(arg.C_stride), 1},
dtype::Float32()); dtype::Float32());
} }
checker.set_param(param).execl({AL, BL, CL}); checker.set_param(param).execl({AL, BL, CL});
......
...@@ -9,6 +9,7 @@ ...@@ -9,6 +9,7 @@
import numpy as np import numpy as np
from ..core import Buffer, Parameter from ..core import Buffer, Parameter
from ..core.device import get_default_device
from ..functional import batch_norm2d, sync_batch_norm from ..functional import batch_norm2d, sync_batch_norm
from . import init from . import init
from .module import Module from .module import Module
...@@ -79,16 +80,31 @@ class _BatchNorm(Module): ...@@ -79,16 +80,31 @@ class _BatchNorm(Module):
else: else:
exponential_average_factor = 0.0 # useless exponential_average_factor = 0.0 # useless
output = batch_norm2d( # FIXME currently rocm does not support real bn opr so we just use
inp, # sync_batch_norm(as implemented by elemwise) here,
self.running_mean, # we will fix it in the next version
self.running_var, if get_default_device() == "rocmx":
self.weight, output = sync_batch_norm(
self.bias, inp,
self.training or not self.track_running_stats, self.running_mean,
exponential_average_factor, self.running_var,
self.eps, self.weight,
) self.bias,
self.training or not self.track_running_stats,
exponential_average_factor,
self.eps,
)
else:
output = batch_norm2d(
inp,
self.running_mean,
self.running_var,
self.weight,
self.bias,
self.training or not self.track_running_stats,
exponential_average_factor,
self.eps,
)
if _ndims != 4: if _ndims != 4:
output = output.reshape(origin_shape) output = output.reshape(origin_shape)
......
...@@ -1013,7 +1013,8 @@ void add_update_impl(const DeviceTensorND& dest, ...@@ -1013,7 +1013,8 @@ void add_update_impl(const DeviceTensorND& dest,
auto&& cn = dest.comp_node(); auto&& cn = dest.comp_node();
using DT = CompNode::DeviceType; using DT = CompNode::DeviceType;
mgb_assert(cn == delta_nobrd.comp_node() && mgb_assert(cn == delta_nobrd.comp_node() &&
(cn.device_type() == DT::CUDA || cn.device_type() == DT::CPU)); (cn.device_type() == DT::CUDA || cn.device_type() == DT::CPU ||
cn.device_type() == DT::ROCM));
mgb_assert(dest.dtype() == delta_nobrd.dtype()); mgb_assert(dest.dtype() == delta_nobrd.dtype());
auto&& delta = delta_nobrd.sub(SubTensorSpec::make_from_offset_elem( auto&& delta = delta_nobrd.sub(SubTensorSpec::make_from_offset_elem(
delta_nobrd.layout().broadcast(dest.shape()), 0)); delta_nobrd.layout().broadcast(dest.shape()), 0));
......
...@@ -13,6 +13,7 @@ ...@@ -13,6 +13,7 @@
#define _HEADER_MGB_BUILD_CONFIG #define _HEADER_MGB_BUILD_CONFIG
#cmakedefine01 MGB_CUDA #cmakedefine01 MGB_CUDA
#cmakedefine01 MGB_ROCM
#cmakedefine01 MGB_CAMBRICON #cmakedefine01 MGB_CAMBRICON
#cmakedefine01 MGB_ATLAS #cmakedefine01 MGB_ATLAS
#cmakedefine01 MGB_ASSERT_LOC #cmakedefine01 MGB_ASSERT_LOC
...@@ -38,6 +39,7 @@ ...@@ -38,6 +39,7 @@
// Platform macro's // Platform macro's
#cmakedefine01 MEGDNN_WITH_CUDA #cmakedefine01 MEGDNN_WITH_CUDA
#cmakedefine01 MEGDNN_WITH_ROCM
#cmakedefine01 MEGDNN_ARMV7 #cmakedefine01 MEGDNN_ARMV7
#cmakedefine01 MEGDNN_AARCH64 #cmakedefine01 MEGDNN_AARCH64
#cmakedefine01 MEGDNN_ENABLE_FP16_NEON #cmakedefine01 MEGDNN_ENABLE_FP16_NEON
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册