提交 9f2af209 编写于 作者: M Megvii Engine Team

feat(mgb): add enflame comp node

GitOrigin-RevId: 478c8538aa890dddf4e6ca95d1c4bb8a8b49ed8e
上级 15d3b3b9
...@@ -31,6 +31,7 @@ ...@@ -31,6 +31,7 @@
#endif #endif
#if MEGDNN_WITH_CUDA #if MEGDNN_WITH_CUDA
#include "src/cuda/handle.h" #include "src/cuda/handle.h"
#endif #endif
......
...@@ -18,6 +18,7 @@ ...@@ -18,6 +18,7 @@
#endif #endif
#if MEGDNN_WITH_ROCM #if MEGDNN_WITH_ROCM
#include "src/rocm/megcore/computing_context.hpp" #include "src/rocm/megcore/computing_context.hpp"
#endif #endif
......
...@@ -182,7 +182,8 @@ CompNode::Locator CompNode::Locator::parse(const std::string &id) { ...@@ -182,7 +182,8 @@ CompNode::Locator CompNode::Locator::parse(const std::string &id) {
} }
dev_type = DeviceType::MULTITHREAD; dev_type = DeviceType::MULTITHREAD;
ptr += 11; ptr += 11;
} else { }
else {
if (ptr[1] != 'p' || ptr[2] != 'u') { if (ptr[1] != 'p' || ptr[2] != 'u') {
err(); err();
} }
...@@ -237,7 +238,7 @@ CompNode::Locator CompNode::Locator::parse(const std::string &id) { ...@@ -237,7 +238,7 @@ CompNode::Locator CompNode::Locator::parse(const std::string &id) {
//! num_steam store the nr_thread //! num_steam store the nr_thread
std::swap(num_dev, num_stream); std::swap(num_dev, num_stream);
} }
return {dev_type, num_dev, {num_stream}}; return {dev_type, num_dev, {num_stream}};
} }
......
...@@ -1021,13 +1021,12 @@ void CpuCompNode::CpuDispatchableBase::EventImpl::do_device_wait_by( ...@@ -1021,13 +1021,12 @@ void CpuCompNode::CpuDispatchableBase::EventImpl::do_device_wait_by(
{ {
auto type = cn_impl->env().property().type; auto type = cn_impl->env().property().type;
mgb_throw_if( mgb_throw_if(type != CompNode::DeviceType::CPU
type != CompNode::DeviceType::CPU && && type != CompNode::DeviceType::CUDA
type != CompNode::DeviceType::CUDA && type != CompNode::DeviceType::ATLAS
&& type != CompNode::DeviceType::ATLAS && ,
type != CompNode::DeviceType::CAMBRICON, MegBrainError,
MegBrainError, "currently CPU can only wait for CPU, CUDA, ATLAS"
"currently CPU can only wait for CPU, CUDA, ATLAS, CAMBRICON"
); );
} }
......
...@@ -36,6 +36,7 @@ ...@@ -36,6 +36,7 @@
#endif #endif
using namespace mgb; using namespace mgb;
/* =================== MegDNNHandle =================== */ /* =================== MegDNNHandle =================== */
...@@ -232,6 +233,7 @@ void CompNodeEnv::init_cuda_async(int dev, CompNode comp_node, ...@@ -232,6 +233,7 @@ void CompNodeEnv::init_cuda_async(int dev, CompNode comp_node,
} }
#endif #endif
#if MGB_ATLAS #if MGB_ATLAS
void mgb::_on_atlas_error(const char* expr, int err, const char* file, void mgb::_on_atlas_error(const char* expr, int err, const char* file,
...@@ -421,6 +423,7 @@ void CompNodeEnv::fini() { ...@@ -421,6 +423,7 @@ void CompNodeEnv::fini() {
MGB_CUDA_CHECK(cudaStreamDestroy(m_cuda_env.stream)); MGB_CUDA_CHECK(cudaStreamDestroy(m_cuda_env.stream));
} }
#endif #endif
#if MGB_ROCM #if MGB_ROCM
if (m_property.type == DeviceType::ROCM) { if (m_property.type == DeviceType::ROCM) {
m_rocm_env.activate(); m_rocm_env.activate();
...@@ -440,6 +443,7 @@ void CompNodeEnv::fini() { ...@@ -440,6 +443,7 @@ void CompNodeEnv::fini() {
MGB_ATLAS_CHECK(aclrtDestroyStream(m_atlas_env.stream)); MGB_ATLAS_CHECK(aclrtDestroyStream(m_atlas_env.stream));
} }
#endif #endif
} }
#if MGB_ENABLE_COMP_NODE_ASYNC_INIT #if MGB_ENABLE_COMP_NODE_ASYNC_INIT
......
...@@ -73,6 +73,7 @@ std::string CudaError::get_cuda_extra_info() { ...@@ -73,6 +73,7 @@ std::string CudaError::get_cuda_extra_info() {
#endif #endif
} }
AtlasError::AtlasError(const std::string &msg): AtlasError::AtlasError(const std::string &msg):
SystemError(msg) SystemError(msg)
{ {
......
...@@ -82,7 +82,7 @@ class CompNode { ...@@ -82,7 +82,7 @@ class CompNode {
CAMBRICON = 3, CAMBRICON = 3,
ROCM = 8, ROCM = 8,
ATLAS = 9, ATLAS = 9,
MULTITHREAD, MULTITHREAD = 11,
MAX_DEVICE_ID, MAX_DEVICE_ID,
}; };
static constexpr size_t NR_DEVICE_TYPE = static constexpr size_t NR_DEVICE_TYPE =
......
...@@ -63,6 +63,7 @@ ...@@ -63,6 +63,7 @@
#endif //MGB_ENABLE_LOGGING #endif //MGB_ENABLE_LOGGING
#endif //MGB_CUDA #endif //MGB_CUDA
#if MGB_ATLAS #if MGB_ATLAS
#include "megcore_atlas.h" #include "megcore_atlas.h"
#include <atomic> #include <atomic>
...@@ -205,6 +206,7 @@ namespace mgb { ...@@ -205,6 +206,7 @@ namespace mgb {
#endif #endif
#if MGB_ROCM #if MGB_ROCM
[[noreturn]] void _on_hip_error(const char* expr, hipError_t err, [[noreturn]] void _on_hip_error(const char* expr, hipError_t err,
const char* file, const char* func, int line); const char* file, const char* func, int line);
...@@ -369,6 +371,7 @@ public: ...@@ -369,6 +371,7 @@ public:
const ContinuationCtx<cudaStream_t>& cont); const ContinuationCtx<cudaStream_t>& cont);
#endif #endif
#if MGB_ATLAS #if MGB_ATLAS
struct AtlasEnv { struct AtlasEnv {
int device = -1; int device = -1;
......
...@@ -139,6 +139,11 @@ public: ...@@ -139,6 +139,11 @@ public:
CudaError(const std::string& msg); CudaError(const std::string& msg);
}; };
class EnFlameError final : public SystemError {
public:
EnFlameError(const std::string& msg);
};
class AtlasError final: public SystemError { class AtlasError final: public SystemError {
public: public:
AtlasError(const std::string& msg); AtlasError(const std::string& msg);
......
...@@ -166,6 +166,7 @@ TEST(TestCompNode, Load) { ...@@ -166,6 +166,7 @@ TEST(TestCompNode, Load) {
ASSERT_NE(atlas0, atlas1); ASSERT_NE(atlas0, atlas1);
#endif #endif
} }
TEST(TestCompNode, FreeAfterFinalize) { TEST(TestCompNode, FreeAfterFinalize) {
...@@ -754,6 +755,7 @@ TEST(TestCompNodeCambricon, P2PCopy) { ...@@ -754,6 +755,7 @@ TEST(TestCompNodeCambricon, P2PCopy) {
#endif #endif
#endif // MGB_CAMBRICON #endif // MGB_CAMBRICON
#if MGB_ATLAS #if MGB_ATLAS
TEST(TestCompNodeAtlas, D2DCopy) { TEST(TestCompNodeAtlas, D2DCopy) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册