handle.cpp 6.9 KB
Newer Older
1 2 3 4 5 6 7 8 9
#include "hcc_detail/hcc_defs_prologue.h"

#include "src/common/handle_impl.h"
#include "src/common/version_symbol.h"

#include "src/rocm/handle.h"
#include "src/rocm/miopen_with_check.h"
#include "src/rocm/utils.h"

M
Megvii Engine Team 已提交
10 11 12 13 14 15
#include "src/rocm/adaptive_pooling/opr_impl.h"
#include "src/rocm/add_update/opr_impl.h"
#include "src/rocm/argmxx/opr_impl.h"
#include "src/rocm/argsort/opr_impl.h"
#include "src/rocm/batch_normalization/opr_impl.h"
#include "src/rocm/batched_matrix_mul/opr_impl.h"
16 17 18 19
#include "src/rocm/checksum/opr_impl.h"
#include "src/rocm/convolution/opr_impl.h"
#include "src/rocm/elemwise/opr_impl.h"
#include "src/rocm/eye/opr_impl.h"
M
Megvii Engine Team 已提交
20 21 22 23 24 25
#include "src/rocm/fill/opr_impl.h"
#include "src/rocm/indexing_multi_axis_vec/opr_impl.h"
#include "src/rocm/indexing_one_hot/opr_impl.h"
#include "src/rocm/linspace/opr_impl.h"
#include "src/rocm/matrix_mul/opr_impl.h"
#include "src/rocm/param_pack/opr_impl.h"
26
#include "src/rocm/pooling/opr_impl.h"
M
Megvii Engine Team 已提交
27
#include "src/rocm/powc/opr_impl.h"
28 29
#include "src/rocm/reduce/opr_impl.h"
#include "src/rocm/relayout/opr_impl.h"
M
Megvii Engine Team 已提交
30
#include "src/rocm/rng/opr_impl.h"
31
#include "src/rocm/sleep/opr_impl.h"
M
Megvii Engine Team 已提交
32 33
#include "src/rocm/topk/opr_impl.h"
#include "src/rocm/type_cvt/opr_impl.h"
34

35
#include <hip/hip_version.h>
M
Megvii Engine Team 已提交
36
#include <miopen/version.h>
37

38 39 40
#include <cstring>

// Emit a build-time note with the MIOpen version this file is compiled
// against. STR/STR_HELPER is the two-level stringification idiom: the extra
// level lets the MIOPEN_VERSION_* macros expand before `#` turns them into
// string literals.
#define STR_HELPER(x) #x
#define STR(x)        STR_HELPER(x)

#define MIOPEN_VERSION_STR    \
    STR(MIOPEN_VERSION_MAJOR) \
    "." STR(MIOPEN_VERSION_MINOR) "." STR(MIOPEN_VERSION_PATCH)

#pragma message "compile with MIOpen " MIOPEN_VERSION_STR " "

// MIOPEN_VERSION_STR is only expandable while STR is defined, so drop it
// together with its helpers instead of leaving a broken macro behind.
#undef MIOPEN_VERSION_STR
#undef STR
#undef STR_HELPER

namespace megdnn {
M
Megvii Engine Team 已提交
53 54
std::unique_ptr<Handle> Handle::make_rocm_handle(
        megcoreComputingHandle_t computing_handle) {
55 56 57 58
    return std::make_unique<rocm::HandleImpl>(computing_handle);
}
template <typename Opr>
std::unique_ptr<Opr> Handle::create_rocm_operator() {
59
    return static_cast<rocm::HandleImpl*>(this)->create_operator<Opr>();
60
}
M
Megvii Engine Team 已提交
61
#define INST(opr) template std::unique_ptr<opr> Handle::create_rocm_operator();
62 63
MEGDNN_FOREACH_OPR_CLASS(INST)
#undef INST
M
Megvii Engine Team 已提交
64
}  // namespace megdnn
65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94

namespace megdnn {
namespace rocm {

HandleImpl::HandleImpl(megcoreComputingHandle_t comp_handle)
        : HandleImplHelper(comp_handle, HandleType::ROCM) {
    // Get megcore device handle
    megcoreDeviceHandle_t dev_handle;
    megcoreGetDeviceHandle(comp_handle, &dev_handle);
    int dev_id;
    megcoreGetDeviceID(dev_handle, &dev_id);
    if (dev_id < 0) {
        hip_check(hipGetDevice(&dev_id));
    }
    m_device_id = dev_id;
    hip_check(hipGetDeviceProperties(&m_device_prop, dev_id));
    // Get stream from MegCore computing handle.
    //! no version check
    megcore::getROCMContext(comp_handle, &m_megcore_context);
    rocblas_check(rocblas_create_handle(&m_rocblas_handle));
    //! must call miopenCreateWithStream() to create miopen handle, then the
    //! rocblas_handle of miopen will set to be the same stream , otherwise
    //! miopen create rocblas_handle with default stream
    miopen_check(miopenCreateWithStream(&m_miopen_handle, stream()));

    // Set stream for miopen and rocblas handles.
    rocblas_check(rocblas_set_stream(m_rocblas_handle, stream()));

    // Note that all rocblas scalars (alpha, beta) and scalar results such as
    // dot output resides at device side.
M
Megvii Engine Team 已提交
95 96
    rocblas_check(
            rocblas_set_pointer_mode(m_rocblas_handle, rocblas_pointer_mode_device));
97 98 99 100 101

    // init const scalars
    hip_check(hipMalloc(&m_const_scalars, sizeof(ConstScalars)));
    ConstScalars const_scalars_val;
    const_scalars_val.init();
M
Megvii Engine Team 已提交
102 103 104
    hip_check(hipMemcpyAsync(
            m_const_scalars, &const_scalars_val, sizeof(ConstScalars),
            hipMemcpyHostToDevice, stream()));
105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126
    hip_check(hipStreamSynchronize(stream()));
}

HandleImpl::~HandleImpl() noexcept {
    miopen_check(miopenDestroy(m_miopen_handle));
    rocblas_check(rocblas_destroy_handle(m_rocblas_handle));
    hip_check(hipFree(m_const_scalars));
}

void HandleImpl::ConstScalars::init() {
#if !MEGDNN_DISABLE_FLOAT16
    f16[0].megdnn_x = 0;
    f16[1].megdnn_x = 1;
#endif
    f32[0] = 0;
    f32[1] = 1;
    i32[0] = 0;
    i32[1] = 1;
}

template <typename Opr>
std::unique_ptr<Opr> HandleImpl::create_operator() {
127 128 129
    megdnn_throw(
            "unsupported rocm opr, try export RUNTIME_OVERRIDE_LOG_LEVEL=0 to get more "
            "info");
130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145
    return nullptr;
}

size_t HandleImpl::alignment_requirement() const {
    auto&& prop = m_device_prop;
    MEGDNN_MARK_USED_VAR(prop);
    //! for now, texture functions are not supported.
    return 1u;
}

bool HandleImpl::check_cross_dev_copy_constraint(const TensorLayout& src) {
    // is contiguous or can be hold by
    // relayout::param::try_copy_2d/try_copy_last_contig
    return src.is_contiguous() || src.stride[src.ndim - 1] == 1;
}

146 147
MEGDNN_SPECIALIZE_CREATE_OPERATOR(ArgsortForward);
MEGDNN_SPECIALIZE_CREATE_OPERATOR(ArgsortBackward);
148 149 150 151 152 153 154 155
MEGDNN_SPECIALIZE_CREATE_OPERATOR(ConvolutionForward);
MEGDNN_SPECIALIZE_CREATE_OPERATOR(ConvolutionBackwardData);
MEGDNN_SPECIALIZE_CREATE_OPERATOR(ConvolutionBackwardFilter);
MEGDNN_SPECIALIZE_CREATE_OPERATOR(ElemwiseForward);
MEGDNN_SPECIALIZE_CREATE_OPERATOR(Eye);
MEGDNN_SPECIALIZE_CREATE_OPERATOR(ChecksumForward);
MEGDNN_SPECIALIZE_CREATE_OPERATOR(PoolingForward);
MEGDNN_SPECIALIZE_CREATE_OPERATOR(PoolingBackward);
156 157
MEGDNN_SPECIALIZE_CREATE_OPERATOR(AdaptivePoolingForward);
MEGDNN_SPECIALIZE_CREATE_OPERATOR(AdaptivePoolingBackward);
158 159
MEGDNN_SPECIALIZE_CREATE_OPERATOR(ReduceForward);
MEGDNN_SPECIALIZE_CREATE_OPERATOR(TypeCvt);
160
MEGDNN_SPECIALIZE_CREATE_OPERATOR(TopK);
161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176
MEGDNN_SPECIALIZE_CREATE_OPERATOR(AddUpdateForward);
MEGDNN_SPECIALIZE_CREATE_OPERATOR(MatrixMulForward);
MEGDNN_SPECIALIZE_CREATE_OPERATOR(BatchedMatrixMulForward);
MEGDNN_SPECIALIZE_CREATE_OPERATOR(IndexingOneHotForward);
MEGDNN_SPECIALIZE_CREATE_OPERATOR(IndexingSetOneHotForward);
MEGDNN_SPECIALIZE_CREATE_OPERATOR(UniformRNG);
MEGDNN_SPECIALIZE_CREATE_OPERATOR(GaussianRNG);
MEGDNN_SPECIALIZE_CREATE_OPERATOR(RelayoutForward);
MEGDNN_SPECIALIZE_CREATE_OPERATOR(PowC);
MEGDNN_SPECIALIZE_CREATE_OPERATOR(IndexingMultiAxisVec);
MEGDNN_SPECIALIZE_CREATE_OPERATOR(IndexingSetMultiAxisVec);
MEGDNN_SPECIALIZE_CREATE_OPERATOR(IndexingIncrMultiAxisVec);
MEGDNN_SPECIALIZE_CREATE_OPERATOR(Linspace);
MEGDNN_SPECIALIZE_CREATE_OPERATOR(ArgmaxForward);
MEGDNN_SPECIALIZE_CREATE_OPERATOR(ArgminForward);
MEGDNN_SPECIALIZE_CREATE_OPERATOR(SleepForward);
177 178
MEGDNN_SPECIALIZE_CREATE_OPERATOR(BNForward);
MEGDNN_SPECIALIZE_CREATE_OPERATOR(BNBackward);
179
MEGDNN_SPECIALIZE_CREATE_OPERATOR(ParamPackConcat);
M
Megvii Engine Team 已提交
180
MEGDNN_SPECIALIZE_CREATE_OPERATOR(Fill);
181 182 183 184 185 186 187 188 189 190

#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wpragmas"
#pragma GCC diagnostic ignored "-Winstantiation-after-specialization"
MEGDNN_FOREACH_OPR_CLASS(MEGDNN_INST_CREATE_OPERATOR)
#pragma GCC diagnostic pop

}  // namespace rocm
}  // namespace megdnn

191
MEGDNN_VERSION_SYMBOL3(HIP, HIP_VERSION_MAJOR, HIP_VERSION_MINOR, HIP_VERSION_PATCH);
M
Megvii Engine Team 已提交
192 193
MEGDNN_VERSION_SYMBOL3(
        MIOPEN, MIOPEN_VERSION_MAJOR, MIOPEN_VERSION_MINOR, MIOPEN_VERSION_PATCH);
194
// vim: syntax=cpp.doxygen