Commit 1bce857c authored by Megvii Engine Team, committed by Xinran Xu

fix(mgb/opr-mm): use comp_node of config as default in CollectiveComm

GitOrigin-RevId: 6b43c9fc93a5bdcffa12d81179c1d74d6f96ce56
Parent 27205461
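This commit drops the per-mode `add_output_var` virtual (and the shared `add_output_var_all2all` helper) from `CollectiveComm::ModeTrait`: output creation moves into the `CollectiveComm` constructor, which now accepts at most one input, creates exactly one output, and picks the output's comp node by preferring the one given in `config` and falling back to the input's comp node. Below is a minimal standalone sketch of that selection rule; `pick_output_comp_node` and the simplified `CompNode` struct are illustrative stand-ins, not the MegBrain API.

```cpp
#include <cassert>
#include <cstdio>
#include <vector>

struct CompNode { const char* name; };  // simplified stand-in

// Prefer the comp node from the operator config; otherwise default to the
// comp node of the single input -- the rule this commit introduces.
CompNode pick_output_comp_node(const std::vector<CompNode>& config_cns,
                               const std::vector<CompNode>& input_cns) {
    assert(config_cns.size() <= 1 && input_cns.size() <= 1);
    if (!config_cns.empty())
        return config_cns[0];       // comp_node of config wins
    assert(!input_cns.empty());     // need some source for the default
    return input_cns[0];            // fall back to the input's comp node
}

int main() {
    std::vector<CompNode> cfg{{"gpu1"}}, inp{{"gpu0"}};
    std::printf("%s\n", pick_output_comp_node(cfg, inp).name);  // gpu1
    std::printf("%s\n", pick_output_comp_node({}, inp).name);   // gpu0
}
```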
@@ -107,27 +107,9 @@ protected:
         }
     }

-    static void add_output_var_all2all(CollectiveComm* opr) {
-        mgb_assert(opr->nr_devices() >= 2);
-        auto pname = get_param_name(opr->param());
-        // sublinear would setup opr->config if inputs.size() is 1,
-        // bypass this situation
-        mgb_assert(
-                !opr->config().has_comp_node_set() || opr->input().size() == 1,
-                "comp node should not be set in %s mode", pname);
-        for (auto i : opr->input()) {
-            opr->add_output(ssprintf("%s:%s", pname, i->cname()))
-                    ->comp_node(i->comp_node());
-        }
-    }
-
 public:
     virtual ~ModeTrait() = default;

-    //! add output var for the opr
-    virtual void add_output_var(CollectiveComm* opr,
-                                const CompNode::UnorderedSet& inp_cn) = 0;
-
     /*!
      * \brief the vars on whose comp node the computing should be performed
      * if None, output vars would be used
@@ -188,11 +170,6 @@ public:
 };

 class CollectiveComm::ModeTrait::ALL_GATHER : public ModeTrait {
-    void add_output_var(CollectiveComm* opr,
-                        const CompNode::UnorderedSet&) override {
-        add_output_var_all2all(opr);
-    }
-
     void get_output_var_shape(const CollectiveComm* opr,
                               const TensorShapeArray& ishp,
                               TensorShapeArray& oshp) override {
@@ -231,11 +208,6 @@ class CollectiveComm::ModeTrait::ALL_GATHER : public ModeTrait {
 };

 class CollectiveComm::ModeTrait::REDUCE_SCATTER_SUM : public ModeTrait {
-    void add_output_var(CollectiveComm* opr,
-                        const CompNode::UnorderedSet&) override {
-        add_output_var_all2all(opr);
-    }
-
     void get_output_var_shape(const CollectiveComm* opr,
                               const TensorShapeArray& ishp,
                               TensorShapeArray& oshp) override {
@@ -292,11 +264,6 @@ protected:

 class CollectiveComm::ModeTrait::AllReduceBase : public ReducedBasedTrait,
                                                  public ModeTrait {
-    void add_output_var(CollectiveComm* opr,
-                        const CompNode::UnorderedSet&) override {
-        add_output_var_all2all(opr);
-    }
-
     void get_output_var_shape(const CollectiveComm*,
                               const TensorShapeArray& ishp,
                               TensorShapeArray& oshp) override {
@@ -368,11 +335,6 @@ class CollectiveComm::ModeTrait::ALL_REDUCE_MIN final : public AllReduceBase {

 class CollectiveComm::ModeTrait::ReduceBase : public ReducedBasedTrait,
                                               public ModeTrait {
-    void add_output_var(CollectiveComm* opr,
-                        const CompNode::UnorderedSet& inp_cn) override {
-        add_output_var_all2all(opr);
-    }
-
     void get_output_var_shape(const CollectiveComm* opr,
                               const TensorShapeArray& ishp,
                               TensorShapeArray& oshp) override {
@@ -413,19 +375,6 @@ class CollectiveComm::ModeTrait::REDUCE_SUM final : public ReduceBase {
 };

 class CollectiveComm::ModeTrait::BROADCAST : public ModeTrait {
-    void add_output_var(CollectiveComm* opr,
-                        const CompNode::UnorderedSet&) override {
-        if (opr->input().size() > 0) {
-            add_output_var_all2all(opr);
-            return;
-        }
-        const auto& cns = opr->config().comp_node();
-        mgb_assert(cns.size() == 1, "exactly one comp_node expected, got %zu", cns.size());
-        auto pname = get_param_name(opr->param());
-        opr->add_output(ssprintf("%s:%s", pname, opr->key().c_str()))->comp_node(cns[0]);
-    }
-
     void get_output_var_shape(const CollectiveComm*,
                               const TensorShapeArray& ishp,
                               TensorShapeArray& oshp) override {
@@ -462,11 +411,6 @@ class CollectiveComm::ModeTrait::BROADCAST : public ModeTrait {
 };

 class CollectiveComm::ModeTrait::GATHER : public ModeTrait {
-    void add_output_var(CollectiveComm* opr,
-                        const CompNode::UnorderedSet&) override {
-        add_output_var_all2all(opr);
-    }
-
     void get_output_var_shape(const CollectiveComm* opr,
                               const TensorShapeArray& ishp,
                               TensorShapeArray& oshp) override {
@@ -501,19 +445,6 @@ class CollectiveComm::ModeTrait::GATHER : public ModeTrait {
 };

 class CollectiveComm::ModeTrait::SCATTER : public ModeTrait {
-    void add_output_var(CollectiveComm* opr,
-                        const CompNode::UnorderedSet&) override {
-        if (opr->input().size() > 0) {
-            add_output_var_all2all(opr);
-            return;
-        }
-        const auto& cns = opr->config().comp_node();
-        mgb_assert(cns.size() == 1, "exactly one comp_node expected, got %zu", cns.size());
-        auto pname = get_param_name(opr->param());
-        opr->add_output(ssprintf("%s:%s", pname, opr->key().c_str()))->comp_node(cns[0]);
-    }
-
     void get_output_var_shape(const CollectiveComm* opr,
                               const TensorShapeArray& ishp,
                               TensorShapeArray& oshp) override {
@@ -537,11 +468,6 @@ class CollectiveComm::ModeTrait::SCATTER : public ModeTrait {
 };

 class CollectiveComm::ModeTrait::ALL_TO_ALL : public ModeTrait {
-    void add_output_var(CollectiveComm* opr,
-                        const CompNode::UnorderedSet&) override {
-        add_output_var_all2all(opr);
-    }
-
     void get_output_var_shape(const CollectiveComm* opr,
                               const TensorShapeArray& ishp,
                               TensorShapeArray& oshp) override {
@@ -617,35 +543,35 @@ CollectiveComm::CollectiveComm(
           m_key(key),
           m_dev_buffers(dev_buffer_arr),
           m_disable{disable} {
-    for (auto i : inputs) {
-        mgb_assert(i->comp_node().device_type() == CompNode::DeviceType::CUDA,
-                   "CollectiveComm currectly only supports CUDA");
-    }
-    for (auto i : config.comp_node()) {
-        mgb_assert(i.device_type() == CompNode::DeviceType::CUDA,
-                   "CollectiveComm currectly only supports CUDA");
-    }
-
-    CompNode::UnorderedSet inp_cn;
-    ThinHashSet<int> inp_dev;
-    for (auto i : inputs) {
-        add_input({i});
-        inp_cn.insert(i->comp_node());
-        inp_dev.insert(
-                CompNodeEnv::from_comp_node(i->comp_node()).cuda_env().device);
-    }
-    mgb_assert(
-            inp_dev.size() == inputs.size(),
-            "CollectiveComm inputs should not contain duplicated input device");
-    ModeTrait::from_mode(param.mode).add_output_var(this, inp_cn);
+    // add input
+    mgb_assert(inputs.size() <= 1, "one or zero input expected, got %zu", inputs.size());
+    if (inputs.size() > 0) {
+        mgb_assert(inputs[0]->comp_node().device_type() == CompNode::DeviceType::CUDA,
+                   "CollectiveComm currectly only supports CUDA");
+        add_input({inputs[0]});
+    }
+
+    // add output
+    add_output(ssprintf("%s:%s", get_param_name(param), key.c_str()));
+
+    // set comp node
+    const auto& cns = config.comp_node();
+    mgb_assert(cns.size() <= 1, "one or zero comp node expected, got %zu", cns.size());
+    if (cns.size() > 0) {
+        mgb_assert(cns[0].device_type() == CompNode::DeviceType::CUDA,
+                   "CollectiveComm currectly only supports CUDA");
+        output(0)->comp_node(cns[0]);
+    } else {
+        output(0)->comp_node(inputs[0]->comp_node());
+    }

+    // set debug flag
     const char* c_debug = MGB_GETENV("MGE_MM_OPR_DEBUG");
     if (c_debug != nullptr and strcmp(c_debug, "1") == 0) {
         m_debug_mode = true;
     }

+    // deduplication
     add_equivalence_component<PODHash<Param>>(&m_param);
     add_equivalence_component<PODHash<size_t>>(&m_nr_devices);
     m_hash = XXHash{}.update(key.data(), key.size() * sizeof(char)).digest();
......
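One consequence worth noting: when `config` carries no comp node, the new constructor falls back to `inputs[0]->comp_node()`, so zero-input modes such as a BROADCAST or SCATTER receiver (whose removed `add_output_var` overrides used to assert `cns.size() == 1` themselves) now depend on `config` supplying the comp node. The sketch below makes that implicit precondition explicit; the function name and error text are hypothetical, not part of the MegBrain sources.

```cpp
#include <cstddef>
#include <stdexcept>

// Hypothetical pre-check mirroring the constructor's implicit assumption:
// either config supplies a comp node, or there is an input to default to.
void require_comp_node_source(std::size_t n_config_comp_nodes,
                              std::size_t n_inputs) {
    if (n_config_comp_nodes == 0 && n_inputs == 0)
        throw std::runtime_error(
                "CollectiveComm: zero-input modes (e.g. BROADCAST/SCATTER "
                "receivers) must set a comp node in config");
}

int main() {
    require_comp_node_source(1, 0);  // receiver side: ok, config has it
    require_comp_node_source(0, 1);  // sender side: ok, defaults to input
}
```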