diff --git a/CMakeLists.txt b/CMakeLists.txt index 9ae9366f6d2fe45d21b83d1b94e63fd92801e443..1c4ef52e64df2106fdcc137b5d9e7d2ad62e0b8e 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -52,6 +52,7 @@ option(MGE_BUILD_SDK "Build load_and_run" ON) option(MGE_INFERENCE_ONLY "Build inference only library." OFF) option(MGE_WITH_PYTHON_MODULE "Build MegEngine Python Module." ON) option(MGE_WITH_MKLDNN "Enable Intel MKL_DNN support," ON) +option(MGE_WITH_ROCM "Enable ROCM support" OFF) if (APPLE) set (BUILD_SHARED_LIBS OFF) @@ -442,6 +443,10 @@ if(MGE_WITH_CAMBRICON) set(MGE_CAMBRICON_LIBS "${MGE_CAMBRICON_LIBS}") endif() +if (MGE_WITH_ROCM) + include(cmake/rocm.cmake) +endif () + if(MGE_WITH_ATLAS) include(cmake/aclrt.cmake) diff --git a/cmake/rocm.cmake b/cmake/rocm.cmake new file mode 100644 index 0000000000000000000000000000000000000000..db5c0e748c534d8290b339d31c8307a5b851a302 --- /dev/null +++ b/cmake/rocm.cmake @@ -0,0 +1,100 @@ +if(NOT DEFINED HIP_PATH) + if(NOT DEFINED ENV{HIP_PATH}) + set(HIP_PATH "/opt/rocm/hip" CACHE PATH "Path to which HIP has been installed") + else() + set(HIP_PATH $ENV{HIP_PATH} CACHE PATH "Path to which HIP has been installed") + endif() +endif() +set(CMAKE_MODULE_PATH "${HIP_PATH}/cmake" ${CMAKE_MODULE_PATH}) +find_package(HIP QUIET) +if (HIP_FOUND) + message(STATUS "Found HIP: " ${HIP_VERSION}) +else() + message(FATAL_ERROR "Could not find HIP. Ensure that HIP is either installed in /opt/rocm/hip or the variable HIP_PATH is set to point to the right location.") +endif() + +string(REPLACE "." 
";" HIP_VERSION_LIST "${HIP_VERSION}") +list(GET HIP_VERSION_LIST 0 HIP_VERSION_MAJOR) +list(GET HIP_VERSION_LIST 1 HIP_VERSION_MINOR) +if (NOT "${HIP_VERSION_MAJOR}" STREQUAL "3") + message(FATAL_ERROR "ROCM version 3.7 is required, but found ${HIP_VERSION}. Please update ROCM.") +endif() +if (NOT "${HIP_VERSION_MINOR}" STREQUAL "7") + message(FATAL_ERROR "ROCM version 3.7 is required, but found ${HIP_VERSION}. Please update ROCM.") +endif() + +set(MGE_ROCM_LIBS OpenCL amdhip64 MIOpen rocblas rocrand) + +set(HIP_INCLUDE_DIR ${HIP_ROOT_DIR}/../include) +set(HIP_LIBRARY_DIR ${HIP_ROOT_DIR}/../lib) + +#miopen +get_filename_component(__found_miopen_library ${HIP_ROOT_DIR}/../miopen/lib REALPATH) +find_path(MIOPEN_LIBRARY_DIR + NAMES libMIOpen.so + HINTS ${PC_MIOPEN_INCLUDE_DIRS} ${MIOPEN_ROOT_DIR} ${ROCM_TOOLKIT_INCLUDE} ${__found_miopen_library} + PATH_SUFFIXES lib + DOC "Path to MIOPEN library directory." ) + +if(MIOPEN_LIBRARY_DIR STREQUAL "MIOPEN_LIBRARY_DIR-NOTFOUND") + message(FATAL_ERROR "Can not find MIOPEN Library") +endif() + +get_filename_component(__found_miopen_include ${HIP_ROOT_DIR}/../miopen/include REALPATH) +find_path(MIOPEN_INCLUDE_DIR + NAMES miopen + HINTS ${PC_MIOPEN_INCLUDE_DIRS} ${MIOPEN_ROOT_DIR} ${ROCM_TOOLKIT_INCLUDE} ${__found_miopen_include} + PATH_SUFFIXES include + DOC "Path to MIOPEN include directory." ) + +if(MIOPEN_INCLUDE_DIR STREQUAL "MIOPEN_INCLUDE_DIR-NOTFOUND") + message(FATAL_ERROR "Can not find MIOPEN INCLUDE") +endif() + +#rocblas +get_filename_component(__found_rocblas_library ${HIP_ROOT_DIR}/../rocblas/lib REALPATH) +find_path(ROCBLAS_LIBRARY_DIR + NAMES librocblas.so + HINTS ${PC_ROCBLAS_INCLUDE_DIRS} ${ROCBLAS_ROOT_DIR} ${ROCM_TOOLKIT_INCLUDE} ${__found_rocblas_library} + PATH_SUFFIXES lib + DOC "Path to ROCBLAS library directory." 
) + +if(ROCBLAS_LIBRARY_DIR STREQUAL "ROCBLAS_LIBRARY_DIR-NOTFOUND") + message(FATAL_ERROR "Can not find ROCBLAS Library") +endif() + +get_filename_component(__found_rocblas_include ${HIP_ROOT_DIR}/../rocblas/include REALPATH) +find_path(ROCBLAS_INCLUDE_DIR + NAMES rocblas.h + HINTS ${PC_ROCBLAS_INCLUDE_DIRS} ${ROCBLAS_ROOT_DIR} ${ROCM_TOOLKIT_INCLUDE} ${__found_rocblas_include} + PATH_SUFFIXES include + DOC "Path to ROCBLAS include directory." ) + +if(ROCBLAS_INCLUDE_DIR STREQUAL "ROCBLAS_INCLUDE_DIR-NOTFOUND") + message(FATAL_ERROR "Can not find ROCBLAS INCLUDE") +endif() + +#rocrand +get_filename_component(__found_rocrand_library ${HIP_ROOT_DIR}/../rocrand/lib REALPATH) +find_path(ROCRAND_LIBRARY_DIR + NAMES librocrand.so + HINTS ${PC_ROCRAND_INCLUDE_DIRS} ${ROCRAND_ROOT_DIR} ${ROCM_TOOLKIT_INCLUDE} ${__found_rocrand_library} + PATH_SUFFIXES lib + DOC "Path to ROCRAND library directory." ) + +if(ROCRAND_LIBRARY_DIR STREQUAL "ROCRAND_LIBRARY_DIR-NOTFOUND") + message(FATAL_ERROR "Can not find ROCRAND Library") +endif() + +get_filename_component(__found_rocrand_include ${HIP_ROOT_DIR}/../rocrand/include REALPATH) +find_path(ROCRAND_INCLUDE_DIR + NAMES rocrand.h + HINTS ${PC_ROCRAND_INCLUDE_DIRS} ${ROCRAND_ROOT_DIR} ${ROCM_TOOLKIT_INCLUDE} ${__found_rocrand_include} + PATH_SUFFIXES include + DOC "Path to ROCRAND include directory." 
) + +if(ROCRAND_INCLUDE_DIR STREQUAL "ROCRAND_INCLUDE_DIR-NOTFOUND") + message(FATAL_ERROR "Can not find ROCRAND INCLUDE") +endif() + + diff --git a/dnn/include/megcore_rocm.h b/dnn/include/megcore_rocm.h index 2a99cb46ea81f58895a9e258f006af2389b762cc..60d5270010b66c76a55e9898bda5fc01435bfb94 100644 --- a/dnn/include/megcore_rocm.h +++ b/dnn/include/megcore_rocm.h @@ -16,6 +16,8 @@ #include "hip_header.h" #include "megdnn/internal/visibility_prologue.h" +#include + namespace megcore { struct ROCMContext { hipStream_t stream = nullptr; diff --git a/dnn/include/megdnn/config/config.h b/dnn/include/megdnn/config/config.h index a428d4af7c71325df0d2ed12e30fe85d529f1521..adb7abd086a3bda48868865cf006cab4027ca6c6 100644 --- a/dnn/include/megdnn/config/config.h +++ b/dnn/include/megdnn/config/config.h @@ -11,7 +11,7 @@ #include "megbrain_build_config.h" -#if !defined(__CUDACC__) +#if !defined(__CUDACC__) && !defined(__HIPCC__) #endif // !defined(__CUDACC__) diff --git a/dnn/include/megdnn/dtype.h b/dnn/include/megdnn/dtype.h index bdfff4134579207ac866e7596b3cf352a0d40c74..69d1f2f9d1088041656816782ec1d6a2f35a9768 100644 --- a/dnn/include/megdnn/dtype.h +++ b/dnn/include/megdnn/dtype.h @@ -103,10 +103,12 @@ namespace megdnn { * \brief iterate through each dtype object that can be involved in float * numeric computing */ + #define MEGDNN_FOREACH_COMPUTING_DTYPE_FLOAT(cb) \ cb(::megdnn::dtype::Float32) \ MEGDNN_INC_FLOAT16(cb(::megdnn::dtype::Float16)) \ - MEGDNN_INC_FLOAT16(cb(::megdnn::dtype::BFloat16)) \ + MEGDNN_INC_FLOAT16(cb(::megdnn::dtype::BFloat16)) + /*! 
* \brief iterate through each dtype object that can be involved in integer diff --git a/dnn/include/megdnn/dtype/bfloat16.hpp b/dnn/include/megdnn/dtype/bfloat16.hpp index 7e05c3b8563062aaaf5b65401b7f09fdf9057400..8de1a1e77d4ec8e2ca702bc05d7b17a7cb0bc32d 100644 --- a/dnn/include/megdnn/dtype/bfloat16.hpp +++ b/dnn/include/megdnn/dtype/bfloat16.hpp @@ -2809,6 +2809,7 @@ namespace std { /// Numeric limits for bfloat16-precision floats. /// Because of the underlying single-precision implementation of many /// operations, it inherits some properties from `numeric_limits`. +#if !defined(__HIPCC__) template <> class numeric_limits : public numeric_limits { public: @@ -2932,6 +2933,7 @@ public: 0x0001); } }; +#endif #ifdef MEGDNN_CC_HOST #if HALF_ENABLE_CPP11_HASH diff --git a/dnn/src/CMakeLists.txt b/dnn/src/CMakeLists.txt index 6aa38bbf212f0cc69e890504bc7b6b303c3cc7ab..9dfd0bd21074ec8f649906b907548838c705ced2 100644 --- a/dnn/src/CMakeLists.txt +++ b/dnn/src/CMakeLists.txt @@ -37,6 +37,66 @@ if(NOT ${MGE_ARCH} STREQUAL "naive") endif() +############################################################################### +# HIP_COMPILE +############################################################################### +macro (HIP_COMPILE _hip_target _hip_objs) + # Separate the sources from the options + HIP_GET_SOURCES_AND_OPTIONS(_sources + _cmake_options + _hipcc_options + _hcc_options + _nvcc_options + ${ARGN}) + HIP_PREPARE_TARGET_COMMANDS(${_hip_target} + OBJ _generated_files _source_files ${_sources} ${_cmake_options} + HIPCC_OPTIONS ${_hipcc_options} + HCC_OPTIONS ${_hcc_options} + NVCC_OPTIONS ${_nvcc_options}) + if(_source_files) + list(REMOVE_ITEM _sources ${_source_files}) + endif() + + add_custom_target(${_hip_target}) + + # set return value + set (${_hip_objs} ${_generated_files}) +endmacro() + +if (MGE_WITH_ROCM) + file (GLOB_RECURSE SOURCES_ rocm/*.cpp) + list (APPEND SOURCES ${SOURCES_}) + + # FIXME rocm may lost the first hip file, so currently we just create an 
+ # empty file to bypass this error. + file(GLOB start.cpp.hip "" ) + list(APPEND HIP_SOURCES start.cpp.hip) + + file (GLOB_RECURSE HIPSOURCES rocm/*.cpp.hip) + set(HIP_TARGET_NAME hip_kernel) + set(_HIPCC_OPTIONS "-fPIC") + set(_HCC_OPTIONS "-fPIC") + set(_NVCC_OPTIONS "-fPIC") + + list(APPEND HIP_SOURCES ${HIPSOURCES}) + set_source_files_properties(${HIP_SOURCES} PROPERTIES HIP_SOURCE_PROPERTY_FORMAT 1) + HIP_INCLUDE_DIRECTORIES(${PROJECT_SOURCE_DIR}/dnn + ${PROJECT_SOURCE_DIR}/dnn/include + ${PROJECT_BINARY_DIR}/dnn + ${PROJECT_BINARY_DIR}/genfiles + ${PROJECT_BINARY_DIR}/dnn/include + ${HIP_INCLUDE_DIR} + ${MIOPEN_INCLUDE_DIR} + ${ROCBLAS_INCLUDE_DIR} + ${ROCRAND_INCLUDE_DIR}) + hip_compile( + ${HIP_TARGET_NAME} HIPOBJS ${HIP_SOURCES} + HIPCC_OPTIONS ${_HIPCC_OPTIONS} + HCC_OPTIONS ${_HCC_OPTIONS} + NVCC_OPTIONS ${_NVCC_OPTIONS}) + list (APPEND SOURCES ${HIPOBJS}) +endif () + if(MGE_WITH_CUDA) file(GLOB_RECURSE SOURCES_ cuda/*.cpp) list(APPEND SOURCES ${SOURCES_}) @@ -73,6 +133,19 @@ if(MGE_WITH_CUDA) target_link_libraries(megdnn PUBLIC cutlass) endif() +if(MGE_WITH_ROCM) + target_include_directories(megdnn PUBLIC + ${HIP_INCLUDE_DIR} + ${MIOPEN_INCLUDE_DIR} + ${ROCBLAS_INCLUDE_DIR} + ${ROCRAND_INCLUDE_DIR}) + target_link_directories(megdnn PUBLIC + ${HIP_LIBRARY_DIR} + ${MIOPEN_LIBRARY_DIR} + ${ROCBLAS_LIBRARY_DIR} + ${ROCRAND_LIBRARY_DIR}) +endif() + if(${MGE_ARCH} STREQUAL "x86_64" OR ${MGE_ARCH} STREQUAL "i386" OR ${MGE_ARCH} STREQUAL "armv7" OR ${MGE_ARCH} STREQUAL "aarch64") if(MGE_ENABLE_CPUINFO) target_link_libraries(megdnn PRIVATE $) @@ -115,6 +188,10 @@ else() target_link_libraries(megdnn PRIVATE ${MGE_BLAS_LIBS}) endif() +if (MGE_WITH_ROCM) + target_link_libraries(megdnn PRIVATE ${HIPOBJS} ${MGE_ROCM_LIBS}) +endif () + if(MGE_WITH_ATLAS) if (BUILD_SHARED_LIBS) target_link_libraries(megdnn PRIVATE $) diff --git a/dnn/src/common/elemwise/erfinv.h b/dnn/src/common/elemwise/erfinv.h index 
7cf0b56571b5968ae649bf3c6903ce08d14705de..2429014c335ef7af602c461e8748da8609be7f64 100644 --- a/dnn/src/common/elemwise/erfinv.h +++ b/dnn/src/common/elemwise/erfinv.h @@ -40,7 +40,7 @@ * -------------------------------------------------------------------------- */ -#ifndef __CUDACC__ +#if !__CUDACC__ && !__HIPCC__ #include diff --git a/dnn/src/rocm/batched_matrix_mul/opr_impl.h b/dnn/src/rocm/batched_matrix_mul/opr_impl.h index a4dfbc141d7a1d1bcca262b4169a5bcb5e8f1e4f..60ca11172af040aa1ccb9dceed80e7c0acbfbfd1 100644 --- a/dnn/src/rocm/batched_matrix_mul/opr_impl.h +++ b/dnn/src/rocm/batched_matrix_mul/opr_impl.h @@ -27,6 +27,22 @@ public: return 0; } + std::vector get_all_algorithms( + const TensorLayout& /*A*/, const TensorLayout& /*B*/, + const TensorLayout& /*C*/) override { + return {}; + } + + Algorithm* get_algorithm_heuristic(const TensorLayout& /*A*/, + const TensorLayout& /*B*/, + const TensorLayout& /*C*/, + size_t /*workspace_limit_in_bytes*/, + bool /* reproducible */) override { + return nullptr; + } + + const char* get_algorithm_set_name() const override { return "DEFAULT"; } + bool is_thread_safe() const override { return true; } private: diff --git a/dnn/src/rocm/elemwise_helper.cpp b/dnn/src/rocm/elemwise_helper.cpp index 85dc7f34307c234082bf6cc881b62779d9a78531..209e08b62f34bd781ea1ca3ecf2d8c835c67b5fa 100644 --- a/dnn/src/rocm/elemwise_helper.cpp +++ b/dnn/src/rocm/elemwise_helper.cpp @@ -124,6 +124,9 @@ INST_FOR_CTYPE INST_FOR_CTYPE #undef ct #endif +#define ct dt_bfloat16 +INST_FOR_CTYPE +#undef ct #define ct dt_int8 INST_FOR_CTYPE #undef ct @@ -142,6 +145,9 @@ INST_FOR_CTYPE #define ct dt_qint32 INST_FOR_CTYPE #undef ct +#define ct dt_bool +INST_FOR_CTYPE +#undef ct #undef ndim_cb diff --git a/dnn/src/rocm/handle.cpp b/dnn/src/rocm/handle.cpp index cefbe8cf4506d68bf1c0beff7f7a74120ad3a5ab..62da48c2831eac25b0d55465a6aa231b96b9e634 100644 --- a/dnn/src/rocm/handle.cpp +++ b/dnn/src/rocm/handle.cpp @@ -36,6 +36,9 @@ #include 
"src/rocm/argmxx/opr_impl.h" #include "src/rocm/sleep/opr_impl.h" +#include +#include + #include #define STR_HELPER(x) #x @@ -56,7 +59,7 @@ std::unique_ptr Handle::make_rocm_handle(megcoreComputingHandle_t comput } template std::unique_ptr Handle::create_rocm_operator() { - return static_cast(this)->create_operator(); + return static_cast(this)->create_operator(); } #define INST(opr) \ template std::unique_ptr Handle::create_rocm_operator(); @@ -178,7 +181,8 @@ MEGDNN_FOREACH_OPR_CLASS(MEGDNN_INST_CREATE_OPERATOR) } // namespace rocm } // namespace megdnn -MEGDNN_VERSION_SYMBOL(HIP, HIP_VERSION); + +MEGDNN_VERSION_SYMBOL3(HIP, HIP_VERSION_MAJOR, HIP_VERSION_MINOR, HIP_VERSION_PATCH); MEGDNN_VERSION_SYMBOL3(MIOPEN, MIOPEN_VERSION_MAJOR, MIOPEN_VERSION_MINOR, MIOPEN_VERSION_PATCH); // vim: syntax=cpp.doxygen diff --git a/dnn/src/rocm/indexing_multi_axis_vec/kern_apply_opr_incr.cpp.hip b/dnn/src/rocm/indexing_multi_axis_vec/kern_apply_opr_incr.cpp.hip index 88419db77f50d4f4251d7e25c5ef64bc5ef89ff7..6bccf0ef3c1baba2b01460ed0196e57963089fd5 100644 --- a/dnn/src/rocm/indexing_multi_axis_vec/kern_apply_opr_incr.cpp.hip +++ b/dnn/src/rocm/indexing_multi_axis_vec/kern_apply_opr_incr.cpp.hip @@ -11,6 +11,11 @@ #include "hip_header.h" #include "megdnn/dtype.h" +__device__ void atomicAdd(megdnn::dt_bfloat16 *, megdnn::dt_bfloat16) { + asm("s_trap 2;"); + ((int*)0)[0] = 1; +} + #if !MEGDNN_DISABLE_FLOAT16 __device__ void atomicAdd(megdnn::dt_float16 *, megdnn::dt_float16) { asm("s_trap 2;"); diff --git a/dnn/src/rocm/reduce/reduce.cpp.hip b/dnn/src/rocm/reduce/reduce.cpp.hip index 2228a04530e3b051de3f0a6d6d4eb9abd3529668..88dfa632675a362784365418e6af641ef4834063 100644 --- a/dnn/src/rocm/reduce/reduce.cpp.hip +++ b/dnn/src/rocm/reduce/reduce.cpp.hip @@ -36,7 +36,7 @@ MEGDNN_FOREACH_COMPUTING_DTYPE(cb) INST(dt_float16, dt_float16, float) INST(dt_float16, float, float) INST(float, dt_float16, float) - +INST(int, float, float) #undef cb #undef INST diff --git 
a/dnn/src/rocm/type_cvt/type_cvt.cpp.hip b/dnn/src/rocm/type_cvt/type_cvt.cpp.hip index db37d7aad3e95209f2b0309665166a51f46a8e3f..22662adf4f8fa025f069eb17371dce0714da486a 100644 --- a/dnn/src/rocm/type_cvt/type_cvt.cpp.hip +++ b/dnn/src/rocm/type_cvt/type_cvt.cpp.hip @@ -142,6 +142,7 @@ void typecvt_kern_n2n(const TensorND& dest, const TensorND& src, cb(dtype_src, dt_uint8) \ cb(dtype_src, dt_float32) \ cb(dtype_src, dt_float16) \ + cb(dtype_src, dt_bfloat16) \ #else @@ -176,6 +177,7 @@ void typecvt_kern_n2n(const TensorND& dest, const TensorND& src, cb(dt_uint8) \ cb(dt_float32) \ cb(dt_float16) \ + cb(dt_bfloat16) \ #else #define MEGDNN_FOREACH_COMPUTING_CTYPE(cb) \ diff --git a/dnn/src/x86/utils.cpp b/dnn/src/x86/utils.cpp index 11a0d7f7f9f3719614e29f4521d617b8231b75f6..54346b08ba542a0d978d2f703fbeb7acd7a167ce 100644 --- a/dnn/src/x86/utils.cpp +++ b/dnn/src/x86/utils.cpp @@ -259,9 +259,7 @@ void transpose_knc2nsck(const float *src, float *dst, MEGDNN_ATTRIBUTE_TARGET("sse") void x86::disable_denorm() { - //printf("before: %x\n", _mm_getcsr()); _mm_setcsr(_mm_getcsr() | (_MM_FLUSH_ZERO_ON | _MM_DENORMALS_ZERO_ON)); - //printf("after: %x\n", _mm_getcsr()); } } // namespace megdnn diff --git a/dnn/test/CMakeLists.txt b/dnn/test/CMakeLists.txt index 2770fdbee4bf07d59a2445403d72b070c66238ab..b37be5d9a1aa3df0919be1795693424a5681eca3 100644 --- a/dnn/test/CMakeLists.txt +++ b/dnn/test/CMakeLists.txt @@ -36,6 +36,11 @@ if(MGE_WITH_ATLAS) list(APPEND SOURCES ${SOURCES_}) endif() +if (MGE_WITH_ROCM) + file (GLOB_RECURSE SOURCES_ rocm/*.cpp) + list (APPEND SOURCES ${SOURCES_}) +endif() + add_executable(megdnn_test ${SOURCES}) set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-narrowing") @@ -61,6 +66,10 @@ if(MGE_ENABLE_COVERAGE) set(CMAKE_EXE_LINKER_FLAGS "${CMAKE_EXE_LINKER_FLAGS} --coverage") endif() +if (MGE_WITH_ROCM) + target_link_libraries (megdnn_test ${MGE_ROCM_LIBS}) +endif () + if(APPLE OR ANDROID) target_link_libraries(megdnn_test dl) else() diff --git 
a/dnn/test/rocm/indexing_multi_axis_vec.cpp b/dnn/test/rocm/indexing_multi_axis_vec.cpp index 5b425e10a98daf86037ec48946e3eaae9c519e2e..12d4834ee13755df7618535084d29bd6a69a242b 100644 --- a/dnn/test/rocm/indexing_multi_axis_vec.cpp +++ b/dnn/test/rocm/indexing_multi_axis_vec.cpp @@ -202,7 +202,7 @@ TEST_F(ROCM, INDEXING_MULTI_AXIS_VEC_BENCHMARK) { set_rng(1, &rng_inp). set_rng(2, &rng0). set_rng(3, &rng1). - set_proxy({0, 1}); + set_proxy({{0, 1}}); auto time_ms = benchmarker.execs({{1000, 1000, 1000}, {1000, 1000}, {1000}, {1000}}); long io = 2 * 1000 * 1000 * dtype::Float32().size(); printf("io = %.3f GB, random access bandwidth = %.3f GB/s\n", diff --git a/dnn/test/rocm/matrix_mul.cpp b/dnn/test/rocm/matrix_mul.cpp index 0c78d75655d93cac70de9855e6454df6f549ce36..87a2d2e0afa4870beacf5a4f6eb3451908c91d07 100644 --- a/dnn/test/rocm/matrix_mul.cpp +++ b/dnn/test/rocm/matrix_mul.cpp @@ -71,22 +71,22 @@ TEST_F(ROCM, MATRIX_MUL) { BS = TensorShape{k, n}; CS = TensorShape{m, n}; TensorLayout AL, BL, CL; - if (arg.Astride == 0) { + if (arg.A_stride == 0) { AL = TensorLayout(AS, dtype::Float32()); } else { - AL = TensorLayout(AS, {ptrdiff_t(arg.Astride), 1}, + AL = TensorLayout(AS, {ptrdiff_t(arg.A_stride), 1}, dtype::Float32()); } - if (arg.Bstride == 0) { + if (arg.B_stride == 0) { BL = TensorLayout(BS, dtype::Float32()); } else { - BL = TensorLayout(BS, {ptrdiff_t(arg.Bstride), 1}, + BL = TensorLayout(BS, {ptrdiff_t(arg.B_stride), 1}, dtype::Float32()); } - if (arg.Cstride == 0) { + if (arg.C_stride == 0) { CL = TensorLayout(CS, dtype::Float32()); } else { - CL = TensorLayout(CS, {ptrdiff_t(arg.Cstride), 1}, + CL = TensorLayout(CS, {ptrdiff_t(arg.C_stride), 1}, dtype::Float32()); } checker.set_param(param).execl({AL, BL, CL}); diff --git a/python_module/megengine/module/batchnorm.py b/python_module/megengine/module/batchnorm.py index a11e001a8956cdc948f068f5fb103c37c25427a2..ba7556168c435f22974c32ef7fd421f883f7fda2 100644 --- 
a/python_module/megengine/module/batchnorm.py +++ b/python_module/megengine/module/batchnorm.py @@ -9,6 +9,7 @@ import numpy as np from ..core import Buffer, Parameter +from ..core.device import get_default_device from ..functional import batch_norm2d, sync_batch_norm from . import init from .module import Module @@ -79,16 +80,31 @@ class _BatchNorm(Module): else: exponential_average_factor = 0.0 # useless - output = batch_norm2d( - inp, - self.running_mean, - self.running_var, - self.weight, - self.bias, - self.training or not self.track_running_stats, - exponential_average_factor, - self.eps, - ) + # FIXME currently rocm does not support real bn opr so we just use + # sync_batch_norm(as implemented by elemwise) here, + # we will fix it in the next version + if get_default_device() == "rocmx": + output = sync_batch_norm( + inp, + self.running_mean, + self.running_var, + self.weight, + self.bias, + self.training or not self.track_running_stats, + exponential_average_factor, + self.eps, + ) + else: + output = batch_norm2d( + inp, + self.running_mean, + self.running_var, + self.weight, + self.bias, + self.training or not self.track_running_stats, + exponential_average_factor, + self.eps, + ) if _ndims != 4: output = output.reshape(origin_shape) diff --git a/python_module/src/cpp/megbrain_wrap.cpp b/python_module/src/cpp/megbrain_wrap.cpp index a71ada0d00d3949bca88110c6457db8ee0bc4c71..bcb85ae1789952ebb938b07c73c2bc802ad6bd69 100644 --- a/python_module/src/cpp/megbrain_wrap.cpp +++ b/python_module/src/cpp/megbrain_wrap.cpp @@ -1013,7 +1013,8 @@ void add_update_impl(const DeviceTensorND& dest, auto&& cn = dest.comp_node(); using DT = CompNode::DeviceType; mgb_assert(cn == delta_nobrd.comp_node() && - (cn.device_type() == DT::CUDA || cn.device_type() == DT::CPU)); + (cn.device_type() == DT::CUDA || cn.device_type() == DT::CPU || + cn.device_type() == DT::ROCM)); mgb_assert(dest.dtype() == delta_nobrd.dtype()); auto&& delta = 
delta_nobrd.sub(SubTensorSpec::make_from_offset_elem( delta_nobrd.layout().broadcast(dest.shape()), 0)); diff --git a/src/megbrain_build_config.h.in b/src/megbrain_build_config.h.in index f923376ba5cc52588ee5e8288920d90ca6d7bfa1..5db6df2e1916f795923bfb42719e700a71b363c7 100644 --- a/src/megbrain_build_config.h.in +++ b/src/megbrain_build_config.h.in @@ -13,6 +13,7 @@ #define _HEADER_MGB_BUILD_CONFIG #cmakedefine01 MGB_CUDA +#cmakedefine01 MGB_ROCM #cmakedefine01 MGB_CAMBRICON #cmakedefine01 MGB_ATLAS #cmakedefine01 MGB_ASSERT_LOC @@ -38,6 +39,7 @@ // Platform macro's #cmakedefine01 MEGDNN_WITH_CUDA +#cmakedefine01 MEGDNN_WITH_ROCM #cmakedefine01 MEGDNN_ARMV7 #cmakedefine01 MEGDNN_AARCH64 #cmakedefine01 MEGDNN_ENABLE_FP16_NEON