#include "hcc_detail/hcc_defs_prologue.h" #include "src/common/handle_impl.h" #include "src/common/version_symbol.h" #include "src/rocm/handle.h" #include "src/rocm/miopen_with_check.h" #include "src/rocm/utils.h" #include "src/rocm/adaptive_pooling/opr_impl.h" #include "src/rocm/add_update/opr_impl.h" #include "src/rocm/argmxx/opr_impl.h" #include "src/rocm/argsort/opr_impl.h" #include "src/rocm/batch_normalization/opr_impl.h" #include "src/rocm/batched_matrix_mul/opr_impl.h" #include "src/rocm/checksum/opr_impl.h" #include "src/rocm/convolution/opr_impl.h" #include "src/rocm/elemwise/opr_impl.h" #include "src/rocm/eye/opr_impl.h" #include "src/rocm/fill/opr_impl.h" #include "src/rocm/indexing_multi_axis_vec/opr_impl.h" #include "src/rocm/indexing_one_hot/opr_impl.h" #include "src/rocm/linspace/opr_impl.h" #include "src/rocm/matrix_mul/opr_impl.h" #include "src/rocm/param_pack/opr_impl.h" #include "src/rocm/pooling/opr_impl.h" #include "src/rocm/powc/opr_impl.h" #include "src/rocm/reduce/opr_impl.h" #include "src/rocm/relayout/opr_impl.h" #include "src/rocm/rng/opr_impl.h" #include "src/rocm/sleep/opr_impl.h" #include "src/rocm/topk/opr_impl.h" #include "src/rocm/type_cvt/opr_impl.h" #include #include #include #define STR_HELPER(x) #x #define STR(x) STR_HELPER(x) #define MIOPEN_VERSION_STR \ STR(MIOPEN_VERSION_MAJOR) \ "." STR(MIOPEN_VERSION_MINOR) "." 
STR(MIOPEN_VERSION_PATCH) #pragma message "compile with MIOpen " MIOPEN_VERSION_STR " " #undef STR #undef STR_HELPER namespace megdnn { std::unique_ptr Handle::make_rocm_handle( megcoreComputingHandle_t computing_handle) { return std::make_unique(computing_handle); } template std::unique_ptr Handle::create_rocm_operator() { return static_cast(this)->create_operator(); } #define INST(opr) template std::unique_ptr Handle::create_rocm_operator(); MEGDNN_FOREACH_OPR_CLASS(INST) #undef INST } // namespace megdnn namespace megdnn { namespace rocm { HandleImpl::HandleImpl(megcoreComputingHandle_t comp_handle) : HandleImplHelper(comp_handle, HandleType::ROCM) { // Get megcore device handle megcoreDeviceHandle_t dev_handle; megcoreGetDeviceHandle(comp_handle, &dev_handle); int dev_id; megcoreGetDeviceID(dev_handle, &dev_id); if (dev_id < 0) { hip_check(hipGetDevice(&dev_id)); } m_device_id = dev_id; hip_check(hipGetDeviceProperties(&m_device_prop, dev_id)); // Get stream from MegCore computing handle. //! no version check megcore::getROCMContext(comp_handle, &m_megcore_context); rocblas_check(rocblas_create_handle(&m_rocblas_handle)); //! must call miopenCreateWithStream() to create miopen handle, then the //! rocblas_handle of miopen will set to be the same stream , otherwise //! miopen create rocblas_handle with default stream miopen_check(miopenCreateWithStream(&m_miopen_handle, stream())); // Set stream for miopen and rocblas handles. rocblas_check(rocblas_set_stream(m_rocblas_handle, stream())); // Note that all rocblas scalars (alpha, beta) and scalar results such as // dot output resides at device side. 
rocblas_check( rocblas_set_pointer_mode(m_rocblas_handle, rocblas_pointer_mode_device)); // init const scalars hip_check(hipMalloc(&m_const_scalars, sizeof(ConstScalars))); ConstScalars const_scalars_val; const_scalars_val.init(); hip_check(hipMemcpyAsync( m_const_scalars, &const_scalars_val, sizeof(ConstScalars), hipMemcpyHostToDevice, stream())); hip_check(hipStreamSynchronize(stream())); } HandleImpl::~HandleImpl() noexcept { miopen_check(miopenDestroy(m_miopen_handle)); rocblas_check(rocblas_destroy_handle(m_rocblas_handle)); hip_check(hipFree(m_const_scalars)); } void HandleImpl::ConstScalars::init() { #if !MEGDNN_DISABLE_FLOAT16 f16[0].megdnn_x = 0; f16[1].megdnn_x = 1; #endif f32[0] = 0; f32[1] = 1; i32[0] = 0; i32[1] = 1; } template std::unique_ptr HandleImpl::create_operator() { megdnn_throw("unsupported rocm opr"); return nullptr; } size_t HandleImpl::alignment_requirement() const { auto&& prop = m_device_prop; MEGDNN_MARK_USED_VAR(prop); //! for now, texture functions are not supported. 
return 1u; } bool HandleImpl::check_cross_dev_copy_constraint(const TensorLayout& src) { // is contiguous or can be hold by // relayout::param::try_copy_2d/try_copy_last_contig return src.is_contiguous() || src.stride[src.ndim - 1] == 1; } MEGDNN_SPECIALIZE_CREATE_OPERATOR(ArgsortForward); MEGDNN_SPECIALIZE_CREATE_OPERATOR(ArgsortBackward); MEGDNN_SPECIALIZE_CREATE_OPERATOR(ConvolutionForward); MEGDNN_SPECIALIZE_CREATE_OPERATOR(ConvolutionBackwardData); MEGDNN_SPECIALIZE_CREATE_OPERATOR(ConvolutionBackwardFilter); MEGDNN_SPECIALIZE_CREATE_OPERATOR(ElemwiseForward); MEGDNN_SPECIALIZE_CREATE_OPERATOR(Eye); MEGDNN_SPECIALIZE_CREATE_OPERATOR(ChecksumForward); MEGDNN_SPECIALIZE_CREATE_OPERATOR(PoolingForward); MEGDNN_SPECIALIZE_CREATE_OPERATOR(PoolingBackward); MEGDNN_SPECIALIZE_CREATE_OPERATOR(AdaptivePoolingForward); MEGDNN_SPECIALIZE_CREATE_OPERATOR(AdaptivePoolingBackward); MEGDNN_SPECIALIZE_CREATE_OPERATOR(ReduceForward); MEGDNN_SPECIALIZE_CREATE_OPERATOR(TypeCvt); MEGDNN_SPECIALIZE_CREATE_OPERATOR(TopK); MEGDNN_SPECIALIZE_CREATE_OPERATOR(AddUpdateForward); MEGDNN_SPECIALIZE_CREATE_OPERATOR(MatrixMulForward); MEGDNN_SPECIALIZE_CREATE_OPERATOR(BatchedMatrixMulForward); MEGDNN_SPECIALIZE_CREATE_OPERATOR(IndexingOneHotForward); MEGDNN_SPECIALIZE_CREATE_OPERATOR(IndexingSetOneHotForward); MEGDNN_SPECIALIZE_CREATE_OPERATOR(UniformRNG); MEGDNN_SPECIALIZE_CREATE_OPERATOR(GaussianRNG); MEGDNN_SPECIALIZE_CREATE_OPERATOR(RelayoutForward); MEGDNN_SPECIALIZE_CREATE_OPERATOR(PowC); MEGDNN_SPECIALIZE_CREATE_OPERATOR(IndexingMultiAxisVec); MEGDNN_SPECIALIZE_CREATE_OPERATOR(IndexingSetMultiAxisVec); MEGDNN_SPECIALIZE_CREATE_OPERATOR(IndexingIncrMultiAxisVec); MEGDNN_SPECIALIZE_CREATE_OPERATOR(Linspace); MEGDNN_SPECIALIZE_CREATE_OPERATOR(ArgmaxForward); MEGDNN_SPECIALIZE_CREATE_OPERATOR(ArgminForward); MEGDNN_SPECIALIZE_CREATE_OPERATOR(SleepForward); MEGDNN_SPECIALIZE_CREATE_OPERATOR(BNForward); MEGDNN_SPECIALIZE_CREATE_OPERATOR(BNBackward); 
MEGDNN_SPECIALIZE_CREATE_OPERATOR(ParamPackConcat);
MEGDNN_SPECIALIZE_CREATE_OPERATOR(Fill);

// Instantiate create_operator for every class enumerated by
// MEGDNN_FOREACH_OPR_CLASS. The instantiation-after-specialization warning is
// silenced for this region only; -Wpragmas is silenced first so that a
// compiler which does not know that warning name does not complain about the
// pragma itself.
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wpragmas"
#pragma GCC diagnostic ignored "-Winstantiation-after-specialization"
MEGDNN_FOREACH_OPR_CLASS(MEGDNN_INST_CREATE_OPERATOR)
#pragma GCC diagnostic pop

}  // namespace rocm
}  // namespace megdnn

// Version symbols for the HIP and MIOpen versions compiled against.
MEGDNN_VERSION_SYMBOL3(HIP, HIP_VERSION_MAJOR, HIP_VERSION_MINOR, HIP_VERSION_PATCH);
MEGDNN_VERSION_SYMBOL3(
        MIOPEN, MIOPEN_VERSION_MAJOR, MIOPEN_VERSION_MINOR, MIOPEN_VERSION_PATCH);

// vim: syntax=cpp.doxygen