From 1bce857cb8ab64ac3060a77c14fd9dac29a6baaf Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Sat, 15 Aug 2020 17:40:03 +0800 Subject: [PATCH] fix(mgb/opr-mm): use comp_node of config as default in CollectiveComm GitOrigin-RevId: 6b43c9fc93a5bdcffa12d81179c1d74d6f96ce56 --- src/opr-mm/impl/collective_comm.cpp | 110 +++++----------------------- 1 file changed, 18 insertions(+), 92 deletions(-) diff --git a/src/opr-mm/impl/collective_comm.cpp b/src/opr-mm/impl/collective_comm.cpp index efb7a75cf..8305b1880 100644 --- a/src/opr-mm/impl/collective_comm.cpp +++ b/src/opr-mm/impl/collective_comm.cpp @@ -107,27 +107,9 @@ protected: } } - static void add_output_var_all2all(CollectiveComm* opr) { - mgb_assert(opr->nr_devices() >= 2); - auto pname = get_param_name(opr->param()); - // sublinear would setup opr->config if inputs.size() is 1, - // bypass this situation - mgb_assert( - !opr->config().has_comp_node_set() || opr->input().size() == 1, - "comp node should not be set in %s mode", pname); - for (auto i : opr->input()) { - opr->add_output(ssprintf("%s:%s", pname, i->cname())) - ->comp_node(i->comp_node()); - } - } - public: virtual ~ModeTrait() = default; - //! add output var for the opr - virtual void add_output_var(CollectiveComm* opr, - const CompNode::UnorderedSet& inp_cn) = 0; - /*! * \brief the vars on whose comp node the computing should be performed * if None, output vars would be used @@ -188,11 +170,6 @@ public: }; class CollectiveComm::ModeTrait::ALL_GATHER : public ModeTrait { - void add_output_var(CollectiveComm* opr, - const CompNode::UnorderedSet&) override { - add_output_var_all2all(opr); - } - void get_output_var_shape(const CollectiveComm* opr, const TensorShapeArray& ishp, TensorShapeArray& oshp) override { @@ -231,11 +208,6 @@ class CollectiveComm::ModeTrait::ALL_GATHER : public ModeTrait { }; class CollectiveComm::ModeTrait::REDUCE_SCATTER_SUM : public ModeTrait { - void add_output_var(CollectiveComm* opr, - const CompNode::UnorderedSet&) override { - add_output_var_all2all(opr); - } - void get_output_var_shape(const CollectiveComm* opr, const TensorShapeArray& ishp, TensorShapeArray& oshp) override { @@ -292,11 +264,6 @@ protected: class CollectiveComm::ModeTrait::AllReduceBase : public ReducedBasedTrait, public ModeTrait { - void add_output_var(CollectiveComm* opr, - const CompNode::UnorderedSet&) override { - add_output_var_all2all(opr); - } - void get_output_var_shape(const CollectiveComm*, const TensorShapeArray& ishp, TensorShapeArray& oshp) override { @@ -368,11 +335,6 @@ class CollectiveComm::ModeTrait::ALL_REDUCE_MIN final : public AllReduceBase { class CollectiveComm::ModeTrait::ReduceBase : public ReducedBasedTrait, public ModeTrait { - void add_output_var(CollectiveComm* opr, - const CompNode::UnorderedSet& inp_cn) override { - add_output_var_all2all(opr); - } - void get_output_var_shape(const CollectiveComm* opr, const TensorShapeArray& ishp, TensorShapeArray& oshp) override { @@ -413,19 +375,6 @@ class CollectiveComm::ModeTrait::REDUCE_SUM final : public ReduceBase { }; class CollectiveComm::ModeTrait::BROADCAST : public ModeTrait { - void add_output_var(CollectiveComm* opr, - const CompNode::UnorderedSet&) override { - if (opr->input().size() > 0) { - add_output_var_all2all(opr); - return; - } - - const auto& cns = opr->config().comp_node(); - mgb_assert(cns.size() == 1, "exactly one comp_node expected, got %zu", cns.size()); - auto pname = get_param_name(opr->param()); - opr->add_output(ssprintf("%s:%s", pname, opr->key().c_str()))->comp_node(cns[0]); - } - void get_output_var_shape(const CollectiveComm*, const TensorShapeArray& ishp, TensorShapeArray& oshp) override { @@ -462,11 +411,6 @@ class CollectiveComm::ModeTrait::BROADCAST : public ModeTrait { }; class CollectiveComm::ModeTrait::GATHER : public ModeTrait { - void add_output_var(CollectiveComm* opr, - const CompNode::UnorderedSet&) override { - add_output_var_all2all(opr); - } - void get_output_var_shape(const CollectiveComm* opr, const TensorShapeArray& ishp, TensorShapeArray& oshp) override { @@ -501,19 +445,6 @@ class CollectiveComm::ModeTrait::GATHER : public ModeTrait { }; class CollectiveComm::ModeTrait::SCATTER : public ModeTrait { - void add_output_var(CollectiveComm* opr, - const CompNode::UnorderedSet&) override { - if (opr->input().size() > 0) { - add_output_var_all2all(opr); - return; - } - - const auto& cns = opr->config().comp_node(); - mgb_assert(cns.size() == 1, "exactly one comp_node expected, got %zu", cns.size()); - auto pname = get_param_name(opr->param()); - opr->add_output(ssprintf("%s:%s", pname, opr->key().c_str()))->comp_node(cns[0]); - } - void get_output_var_shape(const CollectiveComm* opr, const TensorShapeArray& ishp, TensorShapeArray& oshp) override { @@ -537,11 +468,6 @@ class CollectiveComm::ModeTrait::SCATTER : public ModeTrait { }; class CollectiveComm::ModeTrait::ALL_TO_ALL : public ModeTrait { - void add_output_var(CollectiveComm* opr, - const CompNode::UnorderedSet&) override { - add_output_var_all2all(opr); - } - void get_output_var_shape(const CollectiveComm* opr, const TensorShapeArray& ishp, TensorShapeArray& oshp) override { @@ -617,35 +543,35 @@ CollectiveComm::CollectiveComm( m_key(key), m_dev_buffers(dev_buffer_arr), m_disable{disable} { - for (auto i : inputs) { - mgb_assert(i->comp_node().device_type() == CompNode::DeviceType::CUDA, - "CollectiveComm currectly only supports CUDA"); - } - for (auto i : config.comp_node()) { - mgb_assert(i.device_type() == CompNode::DeviceType::CUDA, + // add input + mgb_assert(inputs.size() <= 1, "one or zero input expected, got %zu", inputs.size()); + if (inputs.size() > 0) { + mgb_assert(inputs[0]->comp_node().device_type() == CompNode::DeviceType::CUDA, "CollectiveComm currectly only supports CUDA"); + add_input({inputs[0]}); } - CompNode::UnorderedSet inp_cn; - ThinHashSet inp_dev; + // add output + add_output(ssprintf("%s:%s", get_param_name(param), key.c_str())); - for (auto i : inputs) { - add_input({i}); - inp_cn.insert(i->comp_node()); - inp_dev.insert( - CompNodeEnv::from_comp_node(i->comp_node()).cuda_env().device); + // set comp node + const auto& cns = config.comp_node(); + mgb_assert(cns.size() <= 1, "one or zero comp node expected, got %zu", cns.size()); + if (cns.size() > 0) { + mgb_assert(cns[0].device_type() == CompNode::DeviceType::CUDA, + "CollectiveComm currectly only supports CUDA"); + output(0)->comp_node(cns[0]); + } else { + output(0)->comp_node(inputs[0]->comp_node()); } - mgb_assert( - inp_dev.size() == inputs.size(), - "CollectiveComm inputs should not contain duplicated input device"); - - ModeTrait::from_mode(param.mode).add_output_var(this, inp_cn); + // set debug flag const char* c_debug = MGB_GETENV("MGE_MM_OPR_DEBUG"); if (c_debug != nullptr and strcmp(c_debug, "1") == 0) { m_debug_mode = true; } + // deduplication add_equivalence_component>(&m_param); add_equivalence_component>(&m_nr_devices); m_hash = XXHash{}.update(key.data(), key.size() * sizeof(char)).digest(); -- GitLab