Commit 1bce857c authored by Megvii Engine Team, committed by Xinran Xu

fix(mgb/opr-mm): use comp_node of config as default in CollectiveComm

GitOrigin-RevId: 6b43c9fc93a5bdcffa12d81179c1d74d6f96ce56
Parent 27205461
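For readers skimming the diff below, the essence of the fix is the new comp-node selection in the constructor: the single output now takes the comp node from the operator config when one is set, and falls back to the input's comp node otherwise. The following is a minimal, self-contained C++ sketch of that defaulting rule; CompNode, Config, and Var here are hypothetical stand-ins, not MegEngine's actual mgb::CompNode / OperatorNodeConfig types.

// Minimal sketch of the new defaulting rule (hypothetical stand-in types,
// not MegEngine's actual API): prefer the comp node set on the config,
// fall back to the single input's comp node otherwise.
#include <cassert>
#include <iostream>
#include <string>
#include <vector>

struct CompNode { std::string name; };                 // stand-in for mgb::CompNode
struct Config   { std::vector<CompNode> comp_node; };  // stand-in for OperatorNodeConfig
struct Var      { CompNode cn; };                      // stand-in for an input VarNode

CompNode pick_output_comp_node(const Config& config, const std::vector<Var>& inputs) {
    assert(inputs.size() <= 1 && "one or zero input expected");
    const auto& cns = config.comp_node;
    assert(cns.size() <= 1 && "one or zero comp node expected");
    if (!cns.empty())
        return cns[0];            // comp node from config takes precedence
    assert(!inputs.empty());      // otherwise an input must supply it
    return inputs[0].cn;
}

int main() {
    std::vector<Var> inputs{Var{CompNode{"gpu0"}}};
    Config with_cn{{CompNode{"gpu1"}}};
    std::cout << pick_output_comp_node(with_cn, inputs).name << '\n';  // gpu1
    std::cout << pick_output_comp_node(Config{}, inputs).name << '\n'; // gpu0
}

Under this rule, a BROADCAST or SCATTER receiver with no input but a single comp node in its config can be constructed, while the previous behavior of deriving the output's comp node from the input is preserved when the config is empty.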
@@ -107,27 +107,9 @@ protected:
}
}
static void add_output_var_all2all(CollectiveComm* opr) {
mgb_assert(opr->nr_devices() >= 2);
auto pname = get_param_name(opr->param());
// sublinear memory would set up opr->config when inputs.size() is 1,
// so bypass that case
mgb_assert(
!opr->config().has_comp_node_set() || opr->input().size() == 1,
"comp node should not be set in %s mode", pname);
for (auto i : opr->input()) {
opr->add_output(ssprintf("%s:%s", pname, i->cname()))
->comp_node(i->comp_node());
}
}
public:
virtual ~ModeTrait() = default;
//! add output var for the opr
virtual void add_output_var(CollectiveComm* opr,
const CompNode::UnorderedSet& inp_cn) = 0;
/*!
* \brief the vars on whose comp node the computing should be performed
* if None, output vars would be used
@@ -188,11 +170,6 @@ public:
};
class CollectiveComm::ModeTrait::ALL_GATHER : public ModeTrait {
void add_output_var(CollectiveComm* opr,
const CompNode::UnorderedSet&) override {
add_output_var_all2all(opr);
}
void get_output_var_shape(const CollectiveComm* opr,
const TensorShapeArray& ishp,
TensorShapeArray& oshp) override {
@@ -231,11 +208,6 @@ class CollectiveComm::ModeTrait::ALL_GATHER : public ModeTrait {
};
class CollectiveComm::ModeTrait::REDUCE_SCATTER_SUM : public ModeTrait {
void add_output_var(CollectiveComm* opr,
const CompNode::UnorderedSet&) override {
add_output_var_all2all(opr);
}
void get_output_var_shape(const CollectiveComm* opr,
const TensorShapeArray& ishp,
TensorShapeArray& oshp) override {
@@ -292,11 +264,6 @@ protected:
class CollectiveComm::ModeTrait::AllReduceBase : public ReducedBasedTrait,
public ModeTrait {
void add_output_var(CollectiveComm* opr,
const CompNode::UnorderedSet&) override {
add_output_var_all2all(opr);
}
void get_output_var_shape(const CollectiveComm*,
const TensorShapeArray& ishp,
TensorShapeArray& oshp) override {
@@ -368,11 +335,6 @@ class CollectiveComm::ModeTrait::ALL_REDUCE_MIN final : public AllReduceBase {
class CollectiveComm::ModeTrait::ReduceBase : public ReducedBasedTrait,
public ModeTrait {
void add_output_var(CollectiveComm* opr,
const CompNode::UnorderedSet& inp_cn) override {
add_output_var_all2all(opr);
}
void get_output_var_shape(const CollectiveComm* opr,
const TensorShapeArray& ishp,
TensorShapeArray& oshp) override {
@@ -413,19 +375,6 @@ class CollectiveComm::ModeTrait::REDUCE_SUM final : public ReduceBase {
};
class CollectiveComm::ModeTrait::BROADCAST : public ModeTrait {
void add_output_var(CollectiveComm* opr,
const CompNode::UnorderedSet&) override {
if (opr->input().size() > 0) {
add_output_var_all2all(opr);
return;
}
const auto& cns = opr->config().comp_node();
mgb_assert(cns.size() == 1, "exactly one comp_node expected, got %zu", cns.size());
auto pname = get_param_name(opr->param());
opr->add_output(ssprintf("%s:%s", pname, opr->key().c_str()))->comp_node(cns[0]);
}
void get_output_var_shape(const CollectiveComm*,
const TensorShapeArray& ishp,
TensorShapeArray& oshp) override {
@@ -462,11 +411,6 @@ class CollectiveComm::ModeTrait::BROADCAST : public ModeTrait {
};
class CollectiveComm::ModeTrait::GATHER : public ModeTrait {
void add_output_var(CollectiveComm* opr,
const CompNode::UnorderedSet&) override {
add_output_var_all2all(opr);
}
void get_output_var_shape(const CollectiveComm* opr,
const TensorShapeArray& ishp,
TensorShapeArray& oshp) override {
@@ -501,19 +445,6 @@ class CollectiveComm::ModeTrait::GATHER : public ModeTrait {
};
class CollectiveComm::ModeTrait::SCATTER : public ModeTrait {
void add_output_var(CollectiveComm* opr,
const CompNode::UnorderedSet&) override {
if (opr->input().size() > 0) {
add_output_var_all2all(opr);
return;
}
const auto& cns = opr->config().comp_node();
mgb_assert(cns.size() == 1, "exactly one comp_node expected, got %zu", cns.size());
auto pname = get_param_name(opr->param());
opr->add_output(ssprintf("%s:%s", pname, opr->key().c_str()))->comp_node(cns[0]);
}
void get_output_var_shape(const CollectiveComm* opr,
const TensorShapeArray& ishp,
TensorShapeArray& oshp) override {
@@ -537,11 +468,6 @@ class CollectiveComm::ModeTrait::SCATTER : public ModeTrait {
};
class CollectiveComm::ModeTrait::ALL_TO_ALL : public ModeTrait {
void add_output_var(CollectiveComm* opr,
const CompNode::UnorderedSet&) override {
add_output_var_all2all(opr);
}
void get_output_var_shape(const CollectiveComm* opr,
const TensorShapeArray& ishp,
TensorShapeArray& oshp) override {
@@ -617,35 +543,35 @@ CollectiveComm::CollectiveComm(
m_key(key),
m_dev_buffers(dev_buffer_arr),
m_disable{disable} {
for (auto i : inputs) {
mgb_assert(i->comp_node().device_type() == CompNode::DeviceType::CUDA,
"CollectiveComm currectly only supports CUDA");
}
for (auto i : config.comp_node()) {
mgb_assert(i.device_type() == CompNode::DeviceType::CUDA,
// add input
mgb_assert(inputs.size() <= 1, "one or zero input expected, got %zu", inputs.size());
if (inputs.size() > 0) {
mgb_assert(inputs[0]->comp_node().device_type() == CompNode::DeviceType::CUDA,
"CollectiveComm currectly only supports CUDA");
add_input({inputs[0]});
}
CompNode::UnorderedSet inp_cn;
ThinHashSet<int> inp_dev;
// add output
add_output(ssprintf("%s:%s", get_param_name(param), key.c_str()));
for (auto i : inputs) {
add_input({i});
inp_cn.insert(i->comp_node());
inp_dev.insert(
CompNodeEnv::from_comp_node(i->comp_node()).cuda_env().device);
// set comp node
const auto& cns = config.comp_node();
mgb_assert(cns.size() <= 1, "one or zero comp node expected, got %zu", cns.size());
if (cns.size() > 0) {
mgb_assert(cns[0].device_type() == CompNode::DeviceType::CUDA,
"CollectiveComm currectly only supports CUDA");
output(0)->comp_node(cns[0]);
} else {
output(0)->comp_node(inputs[0]->comp_node());
}
mgb_assert(
inp_dev.size() == inputs.size(),
"CollectiveComm inputs should not contain duplicated input device");
ModeTrait::from_mode(param.mode).add_output_var(this, inp_cn);
// set debug flag
const char* c_debug = MGB_GETENV("MGE_MM_OPR_DEBUG");
if (c_debug != nullptr and strcmp(c_debug, "1") == 0) {
m_debug_mode = true;
}
// deduplication
add_equivalence_component<PODHash<Param>>(&m_param);
add_equivalence_component<PODHash<size_t>>(&m_nr_devices);
m_hash = XXHash{}.update(key.data(), key.size() * sizeof(char)).digest();
......