custom_cuda.cpp 2.0 KB
Newer Older
1
#include "megbrain/common.h"
2 3 4
#include "megbrain_build_config.h"

#if MGB_CUSTOM_OP
5 6 7 8 9 10 11 12 13

#include "megbrain/comp_node_env.h"
#include "megbrain/custom/data_adaptor.h"
#include "megbrain/custom/platform/custom_cuda.h"

using namespace mgb;

namespace custom {

14 15
#if MGB_CUDA

16 17 18 19 20 21 22 23 24 25
//! \brief Build the CUDA runtime arguments for the device a custom op runs on.
//!
//! \param rt_args runtime arguments of the custom op; its device must be a
//!        CUDA device, otherwise the mgb_assert below fails.
//! \return a CudaRuntimeArgs constructed from the device id and stream stored
//!         in the CUDA env of the corresponding comp node.
const CudaRuntimeArgs get_cuda_runtime_args(const RuntimeArgs& rt_args) {
    // Reject non-CUDA devices up front so the caller gets a targeted message.
    mgb_assert(
            rt_args.device().enumv() == DeviceEnum::cuda,
            "device type should be cuda.");
    const CompNodeEnv& env =
            CompNodeEnv::from_comp_node(to_builtin<CompNode, Device>(rt_args.device()));
    const CompNodeEnv::CudaEnv& cuda_env = env.cuda_env();
    return {cuda_env.device, cuda_env.stream};
}

26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72
//! Return the CUDA device ordinal stored in the CUDA env of the comp node
//! that corresponds to \p device.
int get_cuda_device_id(Device device) {
    const CompNodeEnv& env = CompNodeEnv::from_comp_node(to_builtin<CompNode>(device));
    return env.cuda_env().device;
}

//! Return a pointer to the cudaDeviceProp held inside the CUDA env of the comp
//! node that corresponds to \p device. The pointer addresses a field of that
//! env object, so callers must not free it.
const cudaDeviceProp* get_cuda_device_props(Device device) {
    const CompNodeEnv::CudaEnv& cuda_env =
            CompNodeEnv::from_comp_node(to_builtin<CompNode>(device)).cuda_env();
    return &cuda_env.device_prop;
}

//! Return the CUDA stream stored in the CUDA env of the comp node that
//! corresponds to \p device.
cudaStream_t get_cuda_stream(Device device) {
    const CompNodeEnv::CudaEnv& cuda_env =
            CompNodeEnv::from_comp_node(to_builtin<CompNode>(device)).cuda_env();
    return cuda_env.stream;
}

#else

//! Stub used when MegBrain is built without CUDA: the assertion condition is
//! constant false, so every call reports that CUDA support is missing.
//! NOTE(review): there is no return statement; this relies on mgb_assert's
//! failure path never returning — confirm that holds in every build config.
const CudaRuntimeArgs get_cuda_runtime_args(const RuntimeArgs& /* rt_args */) {
    mgb_assert(
            false,
            "megbrain is not support cuda now, please rebuild megbrain with "
            "CUDA ENABLED");
}

//! Stub used when MegBrain is built without CUDA: the assertion condition is
//! constant false, so every call reports that CUDA support is missing.
//! NOTE(review): there is no return statement; this relies on mgb_assert's
//! failure path never returning — confirm that holds in every build config.
int get_cuda_device_id(Device /* device */) {
    mgb_assert(
            false,
            "megbrain is not support cuda now, please rebuild megbrain with "
            "CUDA ENABLED");
}

//! Stub used when MegBrain is built without CUDA: the assertion condition is
//! constant false, so every call reports that CUDA support is missing.
//! NOTE(review): there is no return statement; this relies on mgb_assert's
//! failure path never returning — confirm that holds in every build config.
const cudaDeviceProp* get_cuda_device_props(Device /* device */) {
    mgb_assert(
            false,
            "megbrain is not support cuda now, please rebuild megbrain with "
            "CUDA ENABLED");
}

//! Stub used when MegBrain is built without CUDA: the assertion condition is
//! constant false, so every call reports that CUDA support is missing.
//! NOTE(review): there is no return statement; this relies on mgb_assert's
//! failure path never returning — confirm that holds in every build config.
cudaStream_t get_cuda_stream(Device /* device */) {
    mgb_assert(
            false,
            "megbrain is not support cuda now, please rebuild megbrain with "
            "CUDA ENABLED");
}

#endif

73
}  // namespace custom
74 75

#endif