custom_cuda.cpp 643 字节
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21
#include "megbrain/common.h"

#include "megbrain/comp_node_env.h"
#include "megbrain/custom/data_adaptor.h"
#include "megbrain/custom/platform/custom_cuda.h"

using namespace mgb;

namespace custom {

/*!
 * \brief Build the CUDA-specific runtime arguments for a custom op.
 *
 * Looks up the computing-node environment that corresponds to the device
 * stored in \p rt_args and packs that environment's CUDA `device` and
 * `stream` fields into a CudaRuntimeArgs.
 *
 * \param rt_args generic runtime arguments; its device() must report
 *                DeviceEnum::cuda, which is checked with mgb_assert.
 * \return a CudaRuntimeArgs initialized from {cuda_env.device, cuda_env.stream}.
 *
 * NOTE(review): the failure mode of mgb_assert (presumably an exception) is
 * defined outside this file -- confirm against megbrain/common.h.
 */
const CudaRuntimeArgs get_cuda_runtime_args(const RuntimeArgs& rt_args) {
    // Only CUDA devices carry a CUDA environment; reject anything else up front.
    mgb_assert(
            rt_args.device().enumv() == DeviceEnum::cuda,
            "device type should be cuda.");
    // Translate the custom-op Device into the builtin CompNode, then fetch the
    // environment object registered for that comp node.
    const CompNodeEnv& env =
            CompNodeEnv::from_comp_node(to_builtin<CompNode, Device>(rt_args.device()));
    const CompNodeEnv::CudaEnv& cuda_env = env.cuda_env();
    return {cuda_env.device, cuda_env.stream};
}

}  // namespace custom