diff --git a/python_module/megengine/_internal/comp_graph_tools.py b/python_module/megengine/_internal/comp_graph_tools.py index 5ef32bd869627ca954cddd3115e7956eec06c397..5777d7d05fa092e78b67dbb8870b2f1969e9be85 100644 --- a/python_module/megengine/_internal/comp_graph_tools.py +++ b/python_module/megengine/_internal/comp_graph_tools.py @@ -260,3 +260,15 @@ def replace_oprs(dst, oprmap): repl_dst_vec.push_back(j) return _mgb._replace_oprs(repl_src_vec, repl_dst_vec, dst_vec) + + +def set_priority_to_id(dest_vars): + """For all oprs in the subgraph constructed by dest_vars + set its priority to id if its original priority is zero + :param dest_vars: target vars representing the graph + """ + dest_vec = _mgb._VectorSymbolVar() + for i in dest_vars: + assert isinstance(i, _mgb.SymbolVar) + dest_vec.push_back(i) + _mgb._set_priority_to_id(dest_vec) diff --git a/python_module/megengine/jit/__init__.py b/python_module/megengine/jit/__init__.py index 37a30759465f6fd61cae6cff620849aca2d2c206..a2ce8238c66de55fa0f617cd795a7c681ca365df 100644 --- a/python_module/megengine/jit/__init__.py +++ b/python_module/megengine/jit/__init__.py @@ -84,6 +84,8 @@ class trace: :param log_level: Log level. :param sublinear_memory_config: Configuration for sublinear memory optimization. If not None, it enables sublinear memory optimization with given setting. + :param allreduce_pack_max_size: Maximum size of an allreduce pack in MB. + If not None, multiple gradients will be packed and synchronized together :param profiling: Whether to profile compiled trace. 
Default: False """ @@ -107,6 +109,7 @@ class trace: opt_level: int = None, log_level: int = None, sublinear_memory_config: SublinearMemoryConfig = None, + allreduce_pack_max_size: int = None, profiling: bool = False ): self.__wrapped__ = func @@ -114,6 +117,7 @@ class trace: self._graph_opt_level = opt_level self._log_level = log_level self._sublinear_memory_config = sublinear_memory_config + self._allreduce_pack_max_size = allreduce_pack_max_size self._status = self._UNSTARTED self._args = None self._kwargs = None @@ -313,6 +317,9 @@ class trace: "sublinear_mem_cofig.num_worker", self._sublinear_memory_config.num_worker, ) + # pack allreduce + if self._allreduce_pack_max_size is not None: + cg.set_option("allreduce_pack_max_size", self._allreduce_pack_max_size) # profile if self._profiling: self._profiler = CompGraphProfiler(cg) @@ -391,6 +398,7 @@ class trace: outputs = [outputs] # _run_wrapped has checked validity of outputs self._sym_outputs = tuple(i._symvar for i in outputs) + mgb.comp_graph_tools.set_priority_to_id(self._outspec) self._compiled_func = graph.get_default_graph().compile(None, self._outspec) def trace(self, *args: Tensor, **kwargs): diff --git a/python_module/megengine/optimizer/optimizer.py b/python_module/megengine/optimizer/optimizer.py index a6c9a5e2f68a3d03bab2a6453addbc461dff03ee..86c02c92b37a1704baaabeaffff9cfc7f6be7363 100644 --- a/python_module/megengine/optimizer/optimizer.py +++ b/python_module/megengine/optimizer/optimizer.py @@ -159,7 +159,6 @@ class Optimizer(metaclass=ABCMeta): :param loss: The obtained loss tensor """ rst = [] - priority = 0 params = [] for group in self.param_groups: for param in group["params"]: @@ -180,14 +179,14 @@ class Optimizer(metaclass=ABCMeta): for param, grad in zip(params, grads): if is_distributed(): - priority += 1 - with opr_priority_scope(cg, -priority): - # all_reduce_mean + with opr_priority_scope(cg, -(2 ** 30)): + # always run all_reduce_mean first except add_update grad = ( 
all_reduce_sum(grad, "grad_" + str(get_group_id())) / get_world_size() ) - with opr_priority_scope(cg, (1 << 30) - priority): + with opr_priority_scope(cg, -(2 ** 31)): + # always run add_update first grad_update = add_update(param.grad, grad) else: grad_update = add_update(param.grad, grad) diff --git a/python_module/src/cpp/megbrain_config.cpp b/python_module/src/cpp/megbrain_config.cpp index b6c70da1e4fe611cfff6aae2e473c8c50f53cb36..4ff7026bc13fab707c547fd3649c40d2443ccdc0 100644 --- a/python_module/src/cpp/megbrain_config.cpp +++ b/python_module/src/cpp/megbrain_config.cpp @@ -66,6 +66,8 @@ bool _config::set_comp_graph_option( SET_CG_OPTION(graph_opt.jit); SET_CG_OPTION(graph_opt.tensorrt); SET_CG_OPTION(graph_opt_level); + SET_CG_OPTION(allreduce_pack_max_size); + SET_CG_OPTION(allreduce_pack_ignore_first); SET_CG_OPTION(var_sanity_check_first_run); SET_CG_OPTION(no_profiling_on_shape_change); SET_CG_OPTION(allocate_static_mem_after_graph_compile); diff --git a/python_module/src/swig/comp_graph_tools.i b/python_module/src/swig/comp_graph_tools.i index 26253fa667f347f38a7cbbb5485ba8d5db9779ad..7f6262df96c010582a9b396cc78a45a329675143 100644 --- a/python_module/src/swig/comp_graph_tools.i +++ b/python_module/src/swig/comp_graph_tools.i @@ -1,3 +1,7 @@ +%{ +#include "megbrain/gopt/framework.h" +%} + %inline { SymbolVarArray _get_owner_opr_inputs(SymbolVar var) { @@ -35,5 +39,17 @@ } return mgb::cg::replace_oprs(vars, oprmap); } + + void _set_priority_to_id(const SymbolVarArray& dest_vars) { + auto on_opr = [](mgb::cg::OperatorNodeBase* opr) { + if (opr->node_prop().attribute().priority == 0) { + opr->node_prop().attribute().priority = opr->id(); + } + }; + mgb::cg::DepOprIter dep_iter{on_opr}; + for (const SymbolVar& var : dest_vars) { + dep_iter.add(var); + } + } } // vim: ft=swig foldmethod=marker foldmarker=f{{{,f}}} diff --git a/src/core/impl/graph/cg_impl.cpp b/src/core/impl/graph/cg_impl.cpp index 
b4eded13cdb955ee22a2be295c520ff37784eba8..0058b3da83c43b004429636cffc4e5c1f5eb94c1 100644 --- a/src/core/impl/graph/cg_impl.cpp +++ b/src/core/impl/graph/cg_impl.cpp @@ -441,12 +441,22 @@ ComputingGraphImpl::CompileState ComputingGraphImpl::compile_prepare( optimizer.verbosity(options().log_level); optimizer.enable_check_result(options().graph_opt_level < 0); if (sopr_stat.has_virtual_grad) { - if (need_opt) + if (need_opt) { +#if MGB_ENABLE_OPR_MM + optimizer.add_pass(); +#endif optimizer.add_preset_passes(false, nullptr, &options()); + } optimizer.add_pass(); } - if (need_opt) + if (need_opt) { optimizer.add_preset_passes(true, nullptr, &options()); +#if MGB_ENABLE_OPR_MM + if (sopr_stat.has_virtual_grad) { + optimizer.add_pass(); + } +#endif + } optimizer.apply_inplace(dest_vars); } #endif diff --git a/src/core/include/megbrain/graph/cg.h b/src/core/include/megbrain/graph/cg.h index 3b33175f58a3559ff43a244249e38fbd81219017..c5bc9c21a4675085ad849ecc7dd08f028809e80d 100644 --- a/src/core/include/megbrain/graph/cg.h +++ b/src/core/include/megbrain/graph/cg.h @@ -327,6 +327,18 @@ class ComputingGraph : public std::enable_shared_from_this, */ int16_t graph_opt_level = 2; + /*! + * max size of allreduce packs in MB + * set this option to zero to disable PackAllReducePass + */ + int16_t allreduce_pack_max_size = 0; + + /*! + * do not pack the first n allreduces + * PackAllReducePass disabled if allreduce_pack_max_size is zero + */ + int16_t allreduce_pack_ignore_first = 2; + /*! 
* set logging level, larger number means more verbose * 0: no log info diff --git a/src/core/include/megbrain/graph/helper.h b/src/core/include/megbrain/graph/helper.h index 189006840026a856eb8a6d5078507ae4f1963900..e799de966e4c69a5bb3a544025b1f69d4d15f267 100644 --- a/src/core/include/megbrain/graph/helper.h +++ b/src/core/include/megbrain/graph/helper.h @@ -183,7 +183,6 @@ SymbolVarArray replace_oprs( SymbolVarArray replace_vars_comp_graph( const SymbolVarArray &dest, ComputingGraph* new_graph); - SymbolVarArray find_h2d(const SymbolVarArray& dest); /*! diff --git a/src/gopt/impl/misc.cpp b/src/gopt/impl/misc.cpp index 6f066e2dd5aaf8ea1d132324bc6cc245fefeab39..f5dce790b87d3c3b9bfabe449fddcc522e4ad493 100644 --- a/src/gopt/impl/misc.cpp +++ b/src/gopt/impl/misc.cpp @@ -17,6 +17,7 @@ #include "megbrain/opr/utility.h" #include "megbrain/serialization/serializer.h" #include "megbrain/serialization/opr_shallow_copy.h" +#include "../../core/impl/graph/cg_impl.h" using namespace mgb; using namespace gopt; @@ -657,4 +658,309 @@ void RemoveRedundantTypeCvtPass::apply(OptState& opt) const { rewriter.apply_inplace(); } +#if MGB_ENABLE_OPR_MM +#include "megbrain/opr/collective_comm.h" + +/* ======================= PackAllReduceScanPass ====================== */ + +const char* PackAllReduceScanPass::name() const { + return "pack_allreduce_scan"; +} + +void PackAllReduceScanPass::apply(OptState& opt) const { + auto comp_graph = opt.graph().comp_graph(); + if (comp_graph->options().allreduce_pack_max_size == 0) return; + auto cb_scan = [this] (OperatorNodeBase* opr) { + if (check_pattern(opr)) { + auto& comm = opr->cast_final_safe(); + VarNode* target = comm.input(0)->owner_opr()->input(0); + // only pack allreduces of grads of the same target + // in case two allreduces depend on each other + size_t id = target->id(); + uint64_t hash = XXHash().update(&id, sizeof(size_t)).digest(); + comm.set_pack_hash(hash); + } + }; + opt.graph().iter(cb_scan); +} + +bool 
PackAllReduceScanPass::check_pattern(OperatorNodeBase* opr) { + if (!opr->same_type()) return false; + auto& comm = opr->cast_final_safe(); + if (comm.param().mode != opr::CollectiveComm::Param::Mode::ALL_REDUCE_SUM) return false; + if (comm.input().size() != 1) return false; + + auto grad = comm.input(0)->owner_opr(); + if (!grad->same_type()) return false; + if (grad->input().size() != 2 or grad->output().size() != 1) return false; + + auto param = grad->input(1)->owner_opr(); + if (!param->same_type() and + !param->same_type()) return false; + if (param->input().size() != 0) return false; + + return true; +} + +/* ======================= PackAllReduceReplacePass ====================== */ + +const char* PackAllReduceReplacePass::name() const { + return "pack_allreduce_replace"; +} + +class PackAllReduceReplacePass::GroupInfo { +public: + GroupInfo(int _device, DType _dtype, + size_t _nr_devices, bool _is_root, int _rank, + std::shared_ptr _group_client, + const std::string& _backend); + + uint64_t hash(uint64_t extra) const; + + int device; + DType dtype; + size_t nr_devices; + bool is_root; + int rank; + std::shared_ptr group_client; + std::string backend; +}; + +PackAllReduceReplacePass::GroupInfo::GroupInfo( + int _device, DType _dtype, + size_t _nr_devices, bool _is_root, int _rank, + std::shared_ptr _group_client, + const std::string& _backend) : + device(_device), dtype(_dtype), + nr_devices(_nr_devices), is_root(_is_root), rank(_rank), + group_client(_group_client), backend(_backend) { +} + +uint64_t PackAllReduceReplacePass::GroupInfo::hash(uint64_t extra) const { + DTypeEnum ev = dtype.enumv(); + const std::string& server_addr = group_client->get_addr(); + return XXHash() + .update(&extra, sizeof(uint64_t)) + .update(&device, sizeof(int)) + .update(&ev, sizeof(DTypeEnum)) + .update(&nr_devices, sizeof(size_t)) + .update(&is_root, sizeof(bool)) + .update(&rank, sizeof(int)) + .update(server_addr.c_str(), server_addr.size()) + .update(backend.c_str(), 
backend.size()) + .digest(); +} + +uint64_t PackAllReduceReplacePass::collect_groups(OperatorNodeBase* opr, + ThinHashMap>& group_info, + ThinHashMap& groups) { + // check CollectiveComm oprs that have been marked in PackAllReduceScanPass + if (!opr->same_type()) return 0; + opr::CollectiveComm& comm = opr->cast_final_safe(); + if (comm.pack_hash() == 0) return 0; // pack_hash not set + + VarNode* var = comm.input(0); + auto info = std::make_shared( + var->comp_node().locator().device, + var->dtype(), + comm.nr_devices(), + comm.is_root(), + comm.rank(), + comm.group_client(), + comm.backend() + ); + uint64_t hash = info->hash(comm.pack_hash()); + if (group_info.find(hash) == group_info.end()) { + group_info.emplace(hash, info); + } + groups[hash].push_back(opr); + return hash; +} + +void PackAllReduceReplacePass::divide_packs( + const ThinHashMap& groups, + ThinHashMap>& packs, + size_t max_size) { + cg::OprNodeArray pack; + size_t sum = 0; + for (auto it : groups) { + uint64_t hash = it.first; + const cg::OprNodeArray& group = it.second; + for (size_t i = 0; i < group.size(); i++) { + OperatorNodeBase* opr = group[i]; + VarNode* var = opr->input(0); + const TensorShape* shape = var->owner_graph() + ->static_infer_manager().infer_shape_fallible(var); + if (shape == nullptr) continue; + pack.push_back(opr); + sum += var->dtype().size(shape->total_nr_elems()); + if (sum >= max_size) { + if (pack.size() > 1) packs[hash].push_back(pack); + pack.clear(); + sum = 0; + } + } + if (pack.size() > 1) packs[hash].push_back(pack); + pack.clear(); + sum = 0; + } +} + +void PackAllReduceReplacePass::insert_packed_oprs( + size_t pack_id, + const cg::OprNodeArray& pack, + std::shared_ptr info, + ThinHashMap& replace_map, int priority) { + // set priority + mgb_assert(pack.size() > 0); + auto graph = pack[0]->owner_graph(); + auto on_opr_inserted = [priority] (const cg::event::OprInserted& event) { + event.opr->node_prop().attribute().priority = priority; + }; + auto handler = 
graph->event().register_receiver(on_opr_inserted); + + // flatten inputs and record shapes and partition + std::vector shapes; + SymbolVarArray flattens; + SymbolVarArray partition; + for (size_t i = 0; i < pack.size(); i++) { + VarNode* var = pack[i]->input(0); + auto shape = opr::GetVarShape::make(SymbolVar(var)); + shapes.push_back(shape); + SymbolVar flatten = SymbolVar(var).flatten(); + flattens.push_back(flatten); + partition.push_back(opr::Reduce::make(shape, {opr::Reduce::Mode::PRODUCT, 0})); + } + + // concat + SymbolVar concat = opr::Concat::make(flattens, 0); + + // allreduce + std::string key = ssprintf("grad_pack_%zu", pack_id); + auto param = opr::CollectiveComm::Param::Mode::ALL_REDUCE_SUM; + SymbolVar allreduce = opr::CollectiveComm::make({concat}, graph, + key, info->nr_devices, info->is_root, info->rank, + info->group_client, param, info->dtype, info->backend)[0]; + + // split according to recorded partition + SymbolVarArray splits = opr::Split::make(allreduce, + opr::Split::Options::make_partition(0, partition)); + + // reshape and insert results into replace_map + mgb_assert(pack.size() == splits.size()); + for (size_t i = 0; i < pack.size(); i++) { + VarNode* reshape = splits[i].reshape(shapes[i]).node(); + replace_map[pack[i]->output(0)] = reshape; + } +} + +void PackAllReduceReplacePass::apply(OptState& opt) const { + // get graph options + auto comp_graph = opt.graph().comp_graph(); + size_t max_size = comp_graph->options().allreduce_pack_max_size * 1024 * 1024; + size_t ignore_first = comp_graph->options().allreduce_pack_ignore_first; + if (max_size == 0) return; + + // get topo order + auto& topo_sorter = static_cast(comp_graph)->topo_sorter(); + cg::CompSeqExtraInfo extra_info; + VarNodeArray endpoints = to_var_node_array(opt.graph().endpoint_vars()); + const cg::OprNodeArray* seq = topo_sorter.get_comp_seq(extra_info, endpoints); + topo_sorter.restore_opr_prop(); + + // collect allreduce groups from topo sequence + ThinHashMap> 
group_info; + ThinHashMap groups; + for (size_t i = 0; i < seq->size(); i++) { + if (seq->at(i)->same_type()) { + // ignore the first several allreduces + if (ignore_first > 0) { + --ignore_first; + } else { + collect_groups(seq->at(i), group_info, groups); + } + } + } + + // divide groups into packs + ThinHashMap> packs; + divide_packs(groups, packs, max_size); + + // make sure that oprs inserted in this pass (reshape, concat, allreduce, + // split, reshape) have higher priority than existing operators + int priority = -seq->size() - 100; + + // insert packed operators and generate replace_map + ThinHashMap replace_map; + size_t pack_id = 0; + for (auto it : packs) { + uint64_t hash = it.first; + for (auto pack : it.second) { + opt.call_with_opr(pack[0], [&]() { + insert_packed_oprs(pack_id, pack, group_info[hash], replace_map, priority); + }, OprPropertyFlag::NONE); + pack_id += 1; + } + } + + // replace vars + auto rewriter = opt.graph().make_rewriter(); + auto cb_replace = [&](OperatorNodeBase* opr) { + for (auto i : opr->input()) { + auto iter = replace_map.find(i); + if (iter != replace_map.end()) { + rewriter.replace_var(i, iter->second, nullptr); + } + } + rewriter.auto_replace_outputs(opr); + }; + opt.graph().iter(cb_replace); + rewriter.apply_inplace(); +} + +#else + +/* ======================= PackAllReduceScanPass ====================== */ + +const char* PackAllReduceScanPass::name() const { + return "pack_allreduce_scan"; +} + +void PackAllReduceScanPass::apply(OptState& opt) const { +} + +bool PackAllReduceScanPass::check_pattern(OperatorNodeBase* opr) { + return true; +} + +/* ======================= PackAllReduceReplacePass ====================== */ + +const char* PackAllReduceReplacePass::name() const { + return "pack_allreduce_replace"; +} + +void PackAllReduceReplacePass::apply(OptState& opt) const {} + +uint64_t PackAllReduceReplacePass::collect_groups( + OperatorNodeBase* opr, + ThinHashMap>& group_info, + ThinHashMap& groups) { + return 0; +} 
+ +void PackAllReduceReplacePass::divide_packs( + const ThinHashMap& groups, + ThinHashMap>& packs, + size_t max_size) { +} + +void PackAllReduceReplacePass::insert_packed_oprs( + size_t pack_id, + const cg::OprNodeArray& pack, + std::shared_ptr info, + ThinHashMap& replace_map, int priority) { +} + +#endif // MGB_ENABLE_OPR_MM + // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} diff --git a/src/gopt/include/megbrain/gopt/misc.h b/src/gopt/include/megbrain/gopt/misc.h index 40d8068796923300cf09c9db8a939514477c1e44..e8a833b5b3d801764ae5ce30fdda730e24a2a5e8 100644 --- a/src/gopt/include/megbrain/gopt/misc.h +++ b/src/gopt/include/megbrain/gopt/misc.h @@ -11,6 +11,8 @@ #pragma once +#include + #include "megbrain/gopt/framework.h" namespace mgb { @@ -90,6 +92,45 @@ namespace gopt { void apply(OptState& opt) const override; }; + //! scan allreduces of param grads + class PackAllReduceScanPass final : public Pass { + public: + const char* name() const override; + void apply(OptState& opt) const override; + + private: + // check pattern param -> grad -> allreduce + static bool check_pattern(OperatorNodeBase* opr); + }; + + //! 
pack allreduces of param grads + class PackAllReduceReplacePass final : public Pass { + public: + class GroupInfo; + + const char* name() const override; + void apply(OptState& opt) const override; + + // collect allreduces and divide into groups + static uint64_t collect_groups( + OperatorNodeBase* opr, + ThinHashMap>& group_info, + ThinHashMap& groups); + + // divide groups into packs, max_size in MB + static void divide_packs( + const ThinHashMap& groups, + ThinHashMap>& packs, + size_t max_size); + + // insert packed operators and update replace_map + static void insert_packed_oprs( + size_t pack_id, + const cg::OprNodeArray& pack, + std::shared_ptr info, + ThinHashMap& replace_map, int priority); + }; + } // namespace gopt } // namespace mgb diff --git a/src/gopt/test/misc.cpp b/src/gopt/test/misc.cpp index fcf7fb271ed5165fbf989235da0f71201c6284cf..ba0637b5b4a216e2f78f8f9ada3b2ff72ecd37b4 100644 --- a/src/gopt/test/misc.cpp +++ b/src/gopt/test/misc.cpp @@ -14,6 +14,7 @@ #include "megbrain/gopt/basic_arith.h" #include "megbrain/gopt/misc.h" #include "megbrain/opr/basic_arith_wrapper.h" +#include "megbrain/opr/blas.h" #include "megbrain/opr/cond.h" #include "megbrain/opr/tensor_manip.h" #include "megbrain/opr/utility.h" @@ -410,4 +411,322 @@ TEST_PASS(RemoveRedundantTypeCvtPass, Basic) { check(x_q8_q8, x_q8_fp32_q8_); } +#if MGB_ENABLE_OPR_MM +#include "megbrain/opr/collective_comm.h" +#include "../../opr-mm/test/mock_client.h" + +TEST_PASS(PackAllReduceScanPass, Basic) { + auto graph = ComputingGraph::make(); + graph->options().allreduce_pack_max_size = 5000; + + auto client = std::make_shared(); + auto cn = CompNode::load("gpux"); + + auto dev_x0 = std::make_shared(cn, TensorShape{3, 5}); + auto dev_x1 = std::make_shared(cn, TensorShape{4, 6}); + auto dev_y0 = std::make_shared(cn, TensorShape{1}); + auto dev_y1 = std::make_shared(cn, TensorShape{1}); + + auto x0 = opr::SharedDeviceTensor::make(*graph, dev_x0); + auto x1 = 
opr::VolatileSharedDeviceTensor::make(*graph, dev_x1); + auto y0 = opr::SharedDeviceTensor::make(*graph, dev_y0); + auto y1 = opr::VolatileSharedDeviceTensor::make(*graph, dev_y1); + + auto grad0 = opr::VirtualGrad::make(y0, x0); + auto grad1 = opr::VirtualGrad::make(y0, x1); + auto grad2 = opr::VirtualGrad::make(y1, x0); + auto grad3 = opr::VirtualGrad::make(y1, x1); + + auto mode = opr::CollectiveComm::Param::Mode::ALL_REDUCE_SUM; + auto comm0 = opr::CollectiveComm::make({grad0}, graph.get(), + "grad0", 2, 0, 0, client, mode)[0]; + auto comm1 = opr::CollectiveComm::make({grad1}, graph.get(), + "grad1", 2, 0, 0, client, mode)[0]; + auto comm2 = opr::CollectiveComm::make({grad2}, graph.get(), + "grad2", 2, 0, 0, client, mode)[0]; + auto comm3 = opr::CollectiveComm::make({grad3}, graph.get(), + "grad3", 2, 0, 0, client, mode)[0]; + + gopt::GraphOptimizer() + .add_pass() + .apply({{comm0, comm1, comm2, comm3}}); + + auto get_hash = [] (const SymbolVar& symvar) { + cg::OperatorNodeBase* opr = symvar.node()->owner_opr(); + return opr->cast_final_safe().pack_hash(); + }; + uint64_t hash0 = get_hash(comm0); + uint64_t hash1 = get_hash(comm1); + uint64_t hash2 = get_hash(comm2); + uint64_t hash3 = get_hash(comm3); + + ASSERT_EQ(hash0, hash1); + ASSERT_EQ(hash2, hash3); + ASSERT_NE(hash0, hash2); +} + +TEST_PASS(PackAllReduceReplacePass, CollectGroups) { + REQUIRE_GPU(2); + auto cns = load_multiple_xpus(2); + auto graph = ComputingGraph::make(); + graph->options().graph_opt_level = 2; + + auto cli0 = std::make_shared("mock_addr0"); + auto cli1 = std::make_shared("mock_addr1"); + + using GroupInfo = gopt::PackAllReduceReplacePass::GroupInfo; + ThinHashMap> group_info; + ThinHashMap groups; + + auto add_opr = [&] (const CompNode& cn, TensorShape shape, const DType& dt, + std::shared_ptr client, uint64_t extra_hash) { + auto dev0 = std::make_shared(cn, shape, dt); + auto wrt = opr::SharedDeviceTensor::make(*graph, dev0); + + auto dev1 = std::make_shared(cn, TensorShape{1}, 
dt); + auto target = opr::SharedDeviceTensor::make(*graph, dev1); + + auto grad = opr::VirtualGrad::make(target, wrt); + + auto comm = opr::CollectiveComm::make( + {grad}, graph.get(), "key", 2, 0, 0, client, + opr::CollectiveComm::Param::Mode::ALL_REDUCE_SUM)[0] + .node()->owner_opr(); + + comm->cast_final_safe().set_pack_hash(extra_hash); + + return gopt::PackAllReduceReplacePass::collect_groups(comm, group_info, groups); + }; + + uint64_t hash0 = add_opr(cns[0], TensorShape{1, 3}, dtype::Float32{}, cli0, 1); + uint64_t hash1 = add_opr(cns[0], TensorShape{2, 4}, dtype::Float32{}, cli0, 1); // same + uint64_t hash2 = add_opr(cns[1], TensorShape{3, 5}, dtype::Float32{}, cli0, 1); // comp_node + uint64_t hash3 = add_opr(cns[0], TensorShape{4, 6}, dtype::Float16{}, cli0, 1); // dtype + uint64_t hash4 = add_opr(cns[0], TensorShape{5, 7}, dtype::Float32{}, cli1, 1); // client + uint64_t hash5 = add_opr(cns[0], TensorShape{6, 8}, dtype::Float32{}, cli0, 2); // extra_hash + + ASSERT_EQ(hash0, hash1); + + std::set s; + s.insert(hash0); + s.insert(hash1); + s.insert(hash2); + s.insert(hash3); + s.insert(hash4); + s.insert(hash5); + ASSERT_EQ(5, s.size()); + + ASSERT_EQ(1, group_info.count(hash0)); + ASSERT_EQ(1, group_info.count(hash1)); + ASSERT_EQ(1, group_info.count(hash2)); + ASSERT_EQ(1, group_info.count(hash3)); + ASSERT_EQ(1, group_info.count(hash4)); + ASSERT_EQ(1, group_info.count(hash5)); + + ASSERT_EQ(2, groups[hash0].size()); + ASSERT_EQ(2, groups[hash1].size()); + ASSERT_EQ(1, groups[hash2].size()); + ASSERT_EQ(1, groups[hash3].size()); + ASSERT_EQ(1, groups[hash4].size()); + ASSERT_EQ(1, groups[hash5].size()); +} + +TEST_PASS(PackAllReduceReplacePass, DividePacks) { + auto cn = CompNode::load("gpux"); + auto graph = ComputingGraph::make(); + auto client = std::make_shared(); + auto mode = opr::CollectiveComm::Param::Mode::ALL_REDUCE_SUM; + + ThinHashMap groups; + ThinHashMap> packs; + + auto insert_opr = [&] (size_t size) { + auto dev = std::make_shared(cn, 
TensorShape{size / sizeof(float)}); + auto sd = opr::SharedDeviceTensor::make(*graph, dev); + auto symvar = opr::CollectiveComm::make({sd}, graph.get(), + "key", 2, 0, 0, client, mode)[0]; + auto opr = symvar.node()->owner_opr(); + auto& comm = opr->cast_final_safe(); + comm.set_pack_hash(1); + return opr; + }; + + auto pack_size = [&] (cg::OprNodeArray& pack) { + size_t sum = 0; + for (size_t i = 0; i < pack.size(); i++) { + auto var = pack[i]->input(0); + sum += var->dtype().size(var->shape().total_nr_elems()); + } + return sum; + }; + + groups[0].push_back(insert_opr(100)); // group0, pack0, size=1100 + groups[0].push_back(insert_opr(300)); // group0, pack0, size=1100 + groups[0].push_back(insert_opr(400)); // group0, pack0, size=1100 + groups[0].push_back(insert_opr(300)); // group0, pack0, size=1100 + groups[0].push_back(insert_opr(500)); // group0, pack1, size=800 + groups[0].push_back(insert_opr(200)); // group0, pack1, size=800 + groups[0].push_back(insert_opr(100)); // group0, pack1, size=800 + + groups[1].push_back(insert_opr(100)); // group1, pack0, size=900 + groups[1].push_back(insert_opr(400)); // group1, pack0, size=900 + groups[1].push_back(insert_opr(300)); // group1, pack0, size=900 + groups[1].push_back(insert_opr(100)); // group1, pack0, size=900 + + gopt::PackAllReduceReplacePass::divide_packs(groups, packs, 1000); + + ASSERT_EQ(2, packs.size()); + + ASSERT_EQ(2, packs[0].size()); + ASSERT_EQ(4, packs[0][0].size()); + ASSERT_EQ(1100, pack_size(packs[0][0])); + ASSERT_EQ(3, packs[0][1].size()); + ASSERT_EQ(800, pack_size(packs[0][1])); + + ASSERT_EQ(1, packs[1].size()); + ASSERT_EQ(4, packs[1][0].size()); + ASSERT_EQ(900, pack_size(packs[1][0])); +} + +TEST_PASS(PackAllReduceReplacePass, InsertPackedOprs) { + auto cn = CompNode::load("gpux"); + auto graph = ComputingGraph::make(); + auto client = std::make_shared(); + auto mode = opr::CollectiveComm::Param::Mode::ALL_REDUCE_SUM; + + size_t nr_devices = 2; + uint32_t rank = 0; + uint32_t root = 
0; + + using GroupInfo = gopt::PackAllReduceReplacePass::GroupInfo; + ThinHashMap> group_info; + ThinHashMap groups; + + auto insert_opr = [&] (const TensorShape& shape) { + auto dev = std::make_shared(cn, shape); + auto sd = opr::SharedDeviceTensor::make(*graph, dev); + auto symvar = opr::CollectiveComm::make({sd}, graph.get(), + "key", nr_devices, rank, root, client, mode)[0]; + auto opr = symvar.node()->owner_opr(); + auto& comm = opr->cast_final_safe(); + comm.set_pack_hash(1); + gopt::PackAllReduceReplacePass::collect_groups(opr, group_info, groups); + return symvar; + }; + + auto shape_x = TensorShape{100, 200}; + auto shape_y = TensorShape{200, 400}; + + auto x = insert_opr(shape_x); + auto y = insert_opr(shape_y); + + ASSERT_EQ(1, group_info.size()); + ASSERT_EQ(1, groups.size()); + auto info = group_info.begin()->second; + auto pack = groups.begin()->second; + size_t pack_id = 0; + ThinHashMap replace_map; + gopt::PackAllReduceReplacePass::insert_packed_oprs(pack_id, pack, info, replace_map, -1); + + auto grad_x = SymbolVar(x.node()->owner_opr()->input(0)); + auto grad_y = SymbolVar(y.node()->owner_opr()->input(0)); + + auto concat = opr::Concat::make({grad_x.flatten(), grad_y.flatten()}, 0); + + std::string key = ssprintf("grad_pack_%zu", pack_id); + auto allreduce = opr::CollectiveComm::make({concat}, graph.get(), + key, nr_devices, rank, root, client, mode)[0]; + + std::vector partition; + partition.push_back(shape_x.total_nr_elems()); + partition.push_back(shape_y.total_nr_elems()); + auto splits = opr::Split::make(allreduce, + opr::Split::Options::make_partition(allreduce, 0, partition)); + + ASSERT_EQ(2, splits.size()); + auto dest_x = splits[0].reshape(shape_x); + auto dest_y = splits[1].reshape(shape_y); + + ASSERT_EQ(2, replace_map.size()); + + ASSERT_TRUE(replace_map.count(x.node()) > 0); + ASSERT_EQ(replace_map.at(x.node()), dest_x.node()); + + ASSERT_TRUE(replace_map.count(y.node()) > 0); + ASSERT_EQ(replace_map.at(y.node()), dest_y.node()); +} 
+ +TEST_PASS(PackAllReduceReplacePass, Equivalence) { + REQUIRE_GPU(2); + auto cns = load_multiple_xpus(2); + auto client = std::make_shared(); + + auto build_graph = [&] (uint32_t rank, std::shared_ptr graph, + SymbolVarArray& array) { + HostTensorGenerator<> gen; + auto cn = cns[rank]; + auto host_x = gen({1, 1000}); + auto host_y = gen({1000, 1}); + + auto dev_x = std::make_shared(cn); + auto dev_y = std::make_shared(cn); + + dev_x->copy_from(*host_x).sync(); + dev_y->copy_from(*host_y).sync(); + + auto x = opr::SharedDeviceTensor::make(*graph, dev_x); + auto y = opr::VolatileSharedDeviceTensor::make(*graph, dev_y); + auto loss = opr::MatrixMul::make(x, y).flatten(); + + auto grad_x = opr::VirtualGrad::make(loss, x); + auto grad_y = opr::VirtualGrad::make(loss, y); + + using Mode = opr::CollectiveComm::Param::Mode; + bool is_root = (rank == 0); + auto reduced_x = opr::CollectiveComm::make({grad_x}, graph.get(), + "x", 2, is_root, rank, client, Mode::ALL_REDUCE_SUM)[0] / 2; + auto reduced_y = opr::CollectiveComm::make({grad_y}, graph.get(), + "y", 2, is_root, rank, client, Mode::ALL_REDUCE_SUM)[0] / 2; + + graph->options().allreduce_pack_max_size = 5000; + graph->options().allreduce_pack_ignore_first = 0; + + auto dest_vars = gopt::GraphOptimizer{} + .add_pass() + .add_pass() + .apply({{reduced_x, reduced_y}}).endpoint_vars(); + + array.emplace_back(reduced_x); + array.emplace_back(reduced_y); + array.emplace_back(dest_vars[0]); + array.emplace_back(dest_vars[1]); + }; + + auto run = [&] (uint32_t rank) { + auto graph = ComputingGraph::make(); + SymbolVarArray array; + build_graph(rank, graph, array); + + HostTensorND host_reduced_x, host_reduced_y, host_dest_0, host_dest_1; + + graph->options().allreduce_pack_max_size = 0; + auto func = graph->compile({make_callback_copy(array[0], host_reduced_x), + make_callback_copy(array[1], host_reduced_y), + make_callback_copy(array[2], host_dest_0), + make_callback_copy(array[3], host_dest_1)}); + func->execute(); + + 
MGB_ASSERT_TENSOR_EQ(host_reduced_x, host_dest_0); + MGB_ASSERT_TENSOR_EQ(host_reduced_y, host_dest_1); + }; + + std::thread t0(run, 0); + std::thread t1(run, 1); + + t0.join(); + t1.join(); +} + +#endif // MGB_ENABLE_OPR_MM + // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} diff --git a/src/opr-mm/impl/collective_comm.cpp b/src/opr-mm/impl/collective_comm.cpp index d735bb1cb45171d59954e2af5c102e84547fa8bb..d4db9f56af0897d237b83d0dd8bed41ff319dad2 100644 --- a/src/opr-mm/impl/collective_comm.cpp +++ b/src/opr-mm/impl/collective_comm.cpp @@ -461,16 +461,7 @@ void CollectiveComm::opr_register() { m_rank = reg_info.rank; m_root = reg_info.root_rank; - MegRayCommunicatorBuilder* builder; - - { - static std::mutex user_data_mtx; - std::unique_lock lk(user_data_mtx); - builder = owner_graph()->options().user_data - .get_user_data_or_create(); - } - - m_megray_comm = builder->get_megray_comm( + m_megray_comm = MegRayCommBuilder::get_megray_comm( reg_info.hash, m_key, m_nr_devices, m_rank, get_megray_backend(m_backend), m_group_client); @@ -736,13 +727,15 @@ cg::OperatorNodeBase* opr_shallow_copy_collective_mm( const cg::OperatorNodeBase& opr_, const VarNodeArray& inputs, const OperatorNodeConfig& config) { auto&& opr = opr_.cast_final_safe(); - return opr::CollectiveComm::make( + auto new_opr = CollectiveComm::make( to_symbol_var_array(inputs), ctx.owner_graph(opr_, inputs), opr.key(), opr.nr_devices(), opr.is_root(), opr.rank(), opr.group_client(), opr.dev_buffers(), opr.param(), opr.dtype(), opr.backend(), config)[0] .node() ->owner_opr(); + new_opr->cast_final_safe().set_pack_hash(opr.pack_hash()); + return new_opr; } MGB_REG_OPR_SHALLOW_COPY(CollectiveComm, opr_shallow_copy_collective_mm); diff --git a/src/opr-mm/impl/io_remote.cpp b/src/opr-mm/impl/io_remote.cpp index 8fea82254fce380ce1a19fb7ba570ada816383a5..5cc588493440dc7d521c5c01952faa87ad697f5b 100644 --- a/src/opr-mm/impl/io_remote.cpp +++ b/src/opr-mm/impl/io_remote.cpp @@ -54,13 +54,7 @@ void 
RemoteSend::scn_do_execute() { auto reg_info = m_group_client->opr_register(m_peer.key, 2, 0, false, comp_node.get_uid()); - auto megray_comm_builder = - owner_graph() - ->options() - .user_data - .get_user_data_or_create(); - - m_megray_comm = megray_comm_builder->get_megray_comm( + m_megray_comm = MegRayCommBuilder::get_megray_comm( reg_info.hash, m_peer.key, 2, 0, MegRay::MEGRAY_UCX, m_group_client); m_init = true; } @@ -158,13 +152,7 @@ void RemoteRecv::scn_do_execute() { m_peer.key, 2, false, 1, comp_node.get_uid()); - auto megray_comm_builder = - owner_graph() - ->options() - .user_data - .get_user_data_or_create(); - - m_megray_comm = megray_comm_builder->get_megray_comm( + m_megray_comm = MegRayCommBuilder::get_megray_comm( reg_info.hash, m_peer.key, 2, 1, MegRay::MEGRAY_UCX, m_group_client); m_init = true; } diff --git a/src/opr-mm/impl/megray_helper.cpp b/src/opr-mm/impl/megray_helper.cpp index 2465f7f5427b7cf46ec83966645e70a888e1362c..c96f49daf2a08820d8433bd9a11edaba0d2e3af1 100644 --- a/src/opr-mm/impl/megray_helper.cpp +++ b/src/opr-mm/impl/megray_helper.cpp @@ -14,8 +14,8 @@ using namespace mgb; using namespace opr; -bool MegRayCommunicatorBuilder::find(uint64_t hash, std::shared_ptr& comm) { - std::unique_lock lk(m_mtx); +bool MegRayCommBuilder::find(uint64_t hash, std::shared_ptr& comm) { + std::unique_lock lk(m_map_mtx); auto it = m_megray_comms.find(hash); if (it != m_megray_comms.end()) { comm = it->second; @@ -24,27 +24,37 @@ bool MegRayCommunicatorBuilder::find(uint64_t hash, std::shared_ptr comm) { - std::unique_lock lk(m_mtx); + std::unique_lock lk(m_map_mtx); m_megray_comms.emplace(hash, comm); } -std::shared_ptr MegRayCommunicatorBuilder::get_megray_comm( +std::shared_ptr MegRayCommBuilder::get_megray_comm( uint64_t hash, std::string key, uint32_t size, uint32_t rank, MegRay::Backend backend, std::shared_ptr group_client) { + { + // singleton pattern + std::unique_lock lk(sm_instance_mtx); + if (sm_instance == nullptr) { + sm_instance = new 
MegRayCommBuilder(); + } + } + std::shared_ptr comm; - if (!find(hash, comm)) { + if (!sm_instance->find(hash, comm)) { comm = MegRay::get_communicator(size, rank, backend); auto uid = comm->get_uid(); auto uids = group_client->gather_uid(uid, key, size, rank); mgb_assert(comm->init(uids) == MegRay::Status::MEGRAY_OK); - emplace(hash, comm); + sm_instance->emplace(hash, comm); } return comm; } -MGB_TYPEINFO_OBJ_IMPL(MegRayCommunicatorBuilder); +MegRayCommBuilder* MegRayCommBuilder::sm_instance = nullptr; + +std::mutex MegRayCommBuilder::sm_instance_mtx; // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} diff --git a/src/opr-mm/include/megbrain/opr/collective_comm.h b/src/opr-mm/include/megbrain/opr/collective_comm.h index 2c7a1dfd0c8f7e14d42dd0d7102f284f200e2309..9ec9a5b44f7b5d25299c763c07cf4b0933071807 100644 --- a/src/opr-mm/include/megbrain/opr/collective_comm.h +++ b/src/opr-mm/include/megbrain/opr/collective_comm.h @@ -81,6 +81,10 @@ public: return m_group_client; } + void set_pack_hash(uint64_t hash) { m_pack_hash = hash; } + + uint64_t pack_hash() const { return m_pack_hash; } + std::shared_ptr megray_ctx() const { return m_megray_ctx; } @@ -123,6 +127,9 @@ private: // whose shape infer should be disabled *during* static infer phase. bool m_enable_shape_infer = false; + //! 
set in PackAllReduceScanPass and used in PackAllReduceReplacePass + uint64_t m_pack_hash = 0; + std::shared_ptr m_megray_ctx; std::shared_ptr m_megray_comm; bool m_init = false; diff --git a/src/opr-mm/include/megbrain/opr/group_manager.h b/src/opr-mm/include/megbrain/opr/group_manager.h index 9e87a89d53cd60cba70663cafd72690ff4ae7f66..c8a172e9a2055011f2ba7ff8645ed0ca7b9a466c 100644 --- a/src/opr-mm/include/megbrain/opr/group_manager.h +++ b/src/opr-mm/include/megbrain/opr/group_manager.h @@ -126,6 +126,8 @@ class GroupClient { virtual ~GroupClient() = default; public: + virtual const std::string& get_addr() const = 0; + virtual GroupManager::RegisterInfo opr_register(const std::string& key, size_t nr_devices, bool is_root, int rank, diff --git a/src/opr-mm/include/megbrain/opr/megray_helper.h b/src/opr-mm/include/megbrain/opr/megray_helper.h index 255af039d71a1d1553998e500cb27d317276daf3..4e9117bd915f6c687256cdddeac30a82c63bedb9 100644 --- a/src/opr-mm/include/megbrain/opr/megray_helper.h +++ b/src/opr-mm/include/megbrain/opr/megray_helper.h @@ -23,18 +23,19 @@ namespace opr { /*! 
* gather MegRay unique ids and build communicator, use hash for deduplication */ -class MegRayCommunicatorBuilder final : public mgb::UserDataContainer::UserData { - MGB_TYPEINFO_OBJ_DECL; - +class MegRayCommBuilder { private: bool find(uint64_t hash, std::shared_ptr& comm); void emplace(uint64_t hash, std::shared_ptr comm); std::unordered_map> m_megray_comms; - std::mutex m_mtx; + std::mutex m_map_mtx; + + static MegRayCommBuilder* sm_instance; + static std::mutex sm_instance_mtx; public: - std::shared_ptr get_megray_comm( + static std::shared_ptr get_megray_comm( uint64_t hash, std::string key, uint32_t size, uint32_t rank, MegRay::Backend backend, std::shared_ptr group_client); diff --git a/src/opr-mm/include/megbrain/opr/mm_handler.h b/src/opr-mm/include/megbrain/opr/mm_handler.h index ccc567d6bcddf7619e44d78ab333f32f1439548f..eaa33f90f05908cd06a4a478b2662971d9a3c3fa 100644 --- a/src/opr-mm/include/megbrain/opr/mm_handler.h +++ b/src/opr-mm/include/megbrain/opr/mm_handler.h @@ -47,7 +47,7 @@ public: uint32_t group_barrier(uint32_t size, uint32_t rank) override; - const std::string& get_addr() const { + const std::string& get_addr() const override { return m_addr; } diff --git a/src/opr-mm/test/collective_comm.cpp b/src/opr-mm/test/collective_comm.cpp index 70910a576fb629a5363cdcbebcb911eec5067238..b24744c11450e14f7af162a3046a985a12f9f753 100644 --- a/src/opr-mm/test/collective_comm.cpp +++ b/src/opr-mm/test/collective_comm.cpp @@ -17,11 +17,10 @@ #include "megbrain/opr/utility.h" #include "megbrain/test/helper.h" #include "megbrain/graph.h" +#include "mock_client.h" using namespace mgb; -namespace { - using Mode = opr::CollectiveComm::Param::Mode; SymbolVar make_all_reduce_output(const Mode mode, @@ -41,41 +40,6 @@ SymbolVarArray make_reduce_scatter_sum_output(const SymbolVarArray& inputs) { rdc, opr::Split::Options::make_average(0, inputs.size())); } -class MockGroupClient final : public opr::GroupClient { - public: - ~MockGroupClient() override = default; - - 
opr::GroupManager::RegisterInfo opr_register(const std::string& key, - size_t nr_devices, - bool is_root, int rank, - uintptr_t stream) { - return m_mgr.opr_register(key, nr_devices, is_root, rank, stream); - } - - std::vector gather_uid(const std::string& uid, - const std::string& key, uint32_t size, uint32_t rank) { - return m_mgr.gather_uid(uid, key, size, rank); - } - - void set_output_shape(const std::string& key, - const TensorShape& shape) override { - m_mgr.set_output_shape(key, shape); - } - - TensorShape get_output_shape(const std::string& key) override { - return m_mgr.get_output_shape(key); - } - - uint32_t group_barrier(uint32_t size, uint32_t rank) override { - return m_mgr.group_barrier(size, rank); - } - - private: - opr::GroupManager m_mgr; -}; - -} // namespace - TEST(TestOprCollectiveComm, AllReduce) { REQUIRE_GPU(2); @@ -88,7 +52,7 @@ TEST(TestOprCollectiveComm, AllReduce) { auto host_x1 = gen({28, 28}); HostTensorND host_y0, host_y1, host_y_expect; - auto client = std::make_shared(); + auto client = std::make_shared(); auto graph = ComputingGraph::make(); auto x0 = opr::Host2DeviceCopy::make(*graph, host_x0, cn0); @@ -126,7 +90,7 @@ TEST(TestOprCollectiveComm, AllReduceMultiThread) { auto host_x1 = gen({28, 28}); HostTensorND host_y0, host_y1, host_y_expect; - auto client = std::make_shared(); + auto client = std::make_shared(); auto run_0 = [&]() { auto graph0 = ComputingGraph::make(); @@ -187,7 +151,7 @@ TEST(TestOprCollectiveComm, AllReduceWithGrad) { HostTensorND host_y0, host_y1, host_y_expect; HostTensorND host_out_grad0, host_out_grad1, host_out_grad_expect; - auto client = std::make_shared(); + auto client = std::make_shared(); auto run_0 = [&]() { // rank 0 auto graph0 = ComputingGraph::make(); @@ -268,7 +232,7 @@ TEST(TestOprCollectiveComm, AllGather) { auto host_x1 = gen({28, 28}); HostTensorND host_y0, host_y1, host_y_expect; - auto client = std::make_shared(); + auto client = std::make_shared(); auto graph = ComputingGraph::make(); 
auto x0 = opr::Host2DeviceCopy::make(*graph, host_x0, cn0); @@ -300,7 +264,7 @@ TEST(TestOprCollectiveComm, AllGatherMultiThread) { auto host_x1 = gen({28, 28}); HostTensorND host_y0, host_y1, host_y_expect; - auto client = std::make_shared(); + auto client = std::make_shared(); auto run_0 = [&]() { // rank 0 auto graph0 = ComputingGraph::make(); @@ -356,7 +320,7 @@ TEST(TestOprCollectiveComm, AllGatherWithGrad) { HostTensorND host_out_grad0, host_out_grad1; HostTensorND host_out_grad0_expect, host_out_grad1_expect; - auto client = std::make_shared(); + auto client = std::make_shared(); auto run_0 = [&]() { // rank 0 auto graph0 = ComputingGraph::make(); @@ -438,7 +402,7 @@ TEST(TestOprCollectiveComm, ReduceScatterSum) { auto host_x1 = gen({28, 28}); HostTensorND host_y0, host_y1, host_y0_expect, host_y1_expect; - auto client = std::make_shared(); + auto client = std::make_shared(); auto graph = ComputingGraph::make(); auto x0 = opr::Host2DeviceCopy::make(*graph, host_x0, cn0); @@ -471,7 +435,7 @@ TEST(TestOprCollectiveComm, ReduceScatterSumMultiThread) { auto host_x1 = gen({8}); HostTensorND host_y0, host_y1, host_y0_expect, host_y1_expect; - auto client = std::make_shared(); + auto client = std::make_shared(); auto run_0 = [&]() { // rank 0 auto graph0 = ComputingGraph::make(); @@ -528,7 +492,7 @@ TEST(TestOprCollectiveComm, ReduceScatterSumWithGrad) { HostTensorND host_y0, host_y1, host_y0_expect, host_y1_expect; HostTensorND host_out_grad0, host_out_grad1, host_out_grad_expect; - auto client = std::make_shared(); + auto client = std::make_shared(); auto run_0 = [&]() { // rank 0 auto graph0 = ComputingGraph::make(); @@ -610,7 +574,7 @@ TEST(TestOprCollectiveComm, ReduceSum) { auto host_x1 = gen({28, 28}); HostTensorND host_y0, host_y1, host_y_expect; - auto client = std::make_shared(); + auto client = std::make_shared(); auto graph = ComputingGraph::make(); auto x0 = opr::Host2DeviceCopy::make(*graph, host_x0, cn0); @@ -641,7 +605,7 @@ 
TEST(TestOprCollectiveComm, ReduceSumMultiThread) { auto host_x1 = gen({28, 28}); HostTensorND host_y0, host_y_expect; - auto client = std::make_shared(); + auto client = std::make_shared(); auto run_0 = [&]() { // rank 0 auto graph0 = ComputingGraph::make(); @@ -694,7 +658,7 @@ TEST(TestOprCollectiveComm, ReduceSumWithGrad) { HostTensorND host_y0, host_y0_expect, host_out_grad0, host_out_grad1; - auto client = std::make_shared(); + auto client = std::make_shared(); auto run_0 = [&]() { // rank 0 auto graph0 = ComputingGraph::make(); @@ -764,7 +728,7 @@ TEST(TestOprCollectiveComm, Broadcast) { auto host_x0 = gen({28, 28}); HostTensorND host_y0, host_y1, host_y_expect; - auto client = std::make_shared(); + auto client = std::make_shared(); auto graph = ComputingGraph::make(); auto x0 = opr::Host2DeviceCopy::make(*graph, host_x0, cn0); @@ -794,7 +758,7 @@ TEST(TestOprCollectiveComm, BroadcastMultiThread) { auto host_x0 = gen({28, 28}); HostTensorND host_y0, host_y1; - auto client = std::make_shared(); + auto client = std::make_shared(); auto run_0 = [&]() { // rank 0 auto graph0 = ComputingGraph::make(); @@ -840,7 +804,7 @@ TEST(TestOprCollectiveComm, BroadcastWithGrad) { HostTensorND host_y0, host_y1, host_out_grad, host_out_grad_expect; - auto client = std::make_shared(); + auto client = std::make_shared(); auto run_0 = [&]() { // rank 0 auto graph0 = ComputingGraph::make(); diff --git a/src/opr-mm/test/io_remote.cpp b/src/opr-mm/test/io_remote.cpp index b79178095c4fdc208bbaeb0fd8361cdd399039cf..f4302dd203dbc3b651b60bb8276d08be1a1e4f1e 100644 --- a/src/opr-mm/test/io_remote.cpp +++ b/src/opr-mm/test/io_remote.cpp @@ -14,51 +14,14 @@ #include "megbrain/opr/utility.h" #include "megbrain/system.h" #include "megbrain/test/helper.h" +#include "mock_client.h" #include using namespace mgb; -using namespace opr; -namespace { - -class MockGroupClient final : public opr::GroupClient { - public: - ~MockGroupClient() override = default; - - opr::GroupManager::RegisterInfo 
opr_register(const std::string& key, - size_t nr_devices, - bool is_root, int rank, - uint64_t comp_node_hash) { - return m_mgr.opr_register(key, nr_devices, is_root, rank, comp_node_hash); - } - - std::vector gather_uid(const std::string& uid, - const std::string& key, uint32_t size, uint32_t rank) { - return m_mgr.gather_uid(uid, key, size, rank); - } - - void set_output_shape(const std::string& key, - const TensorShape& shape) override { - m_mgr.set_output_shape(key, shape); - } - - TensorShape get_output_shape(const std::string& key) override { - return m_mgr.get_output_shape(key); - } - - uint32_t group_barrier(uint32_t size, uint32_t rank) override { - return m_mgr.group_barrier(size, rank); - } - - private: - opr::GroupManager m_mgr; -}; - -const auto send_tag = RemoteIOBase::Type::SEND; -const auto recv_tag = RemoteIOBase::Type::RECV; - -} // anonymous namespace +const auto send_tag = opr::RemoteIOBase::Type::SEND; +const auto recv_tag = opr::RemoteIOBase::Type::RECV; TEST(TestOprIORemote, Identity) { REQUIRE_GPU(2); @@ -69,7 +32,7 @@ TEST(TestOprIORemote, Identity) { auto host_x = gen({28, 28}); HostTensorND host_y; - auto client = std::make_shared(); + auto client = std::make_shared(); auto graph = ComputingGraph::make(); auto x = opr::Host2DeviceCopy::make(*graph, host_x, cn0); @@ -90,7 +53,7 @@ TEST(TestOprIORemote, IdentityMultiThread) { HostTensorGenerator<> gen; auto host_x = gen({2, 3}, cns[1]); HostTensorND host_x_get; - auto client = std::make_shared(); + auto client = std::make_shared(); auto sender = [&]() { auto graph = ComputingGraph::make(); @@ -123,7 +86,7 @@ TEST(TestOprIORemote, IdentityWithGopt) { HostTensorGenerator<> gen; auto host_x = gen({2, 3}, cns[0]); HostTensorND host_x_get; - auto client = std::make_shared(); + auto client = std::make_shared(); auto sender = [&]() { sys::set_thread_name("sender"); @@ -157,7 +120,7 @@ TEST(TestOprIORemote, APlusB) { HostTensorGenerator<> gen; auto host_x = gen({5, 7}, cns[0]), host_y = gen({5, 1}, 
cns[0]); HostTensorND host_z; - auto client = std::make_shared(); + auto client = std::make_shared(); auto sender = [&]() { auto graph = ComputingGraph::make(); @@ -208,7 +171,7 @@ TEST(TestOprIORemote, SendGrad) { HostTensorGenerator<> gen; auto host_x = gen({2, 3}, cns[0]); HostTensorND host_gx, host_loss; - auto client = std::make_shared(); + auto client = std::make_shared(); auto sender = [&]() { sys::set_thread_name("sender"); diff --git a/src/opr-mm/test/mock_client.h b/src/opr-mm/test/mock_client.h new file mode 100644 index 0000000000000000000000000000000000000000..5ca014594978115c6baf393267be67a10c3ef295 --- /dev/null +++ b/src/opr-mm/test/mock_client.h @@ -0,0 +1,62 @@ +/** + * \file src/opr-mm/test/mock_client.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ */
+
+#include "megbrain/opr/group_manager.h"  // NOTE(review): header has no include guard / #pragma once — confirm it is never included twice in one TU
+
+namespace mgb {
+namespace test {
+
+class MockGroupClient final : public opr::GroupClient {  // in-process stand-in: forwards every GroupClient call to a local GroupManager
+    public:
+        using RegisterInfo = opr::GroupManager::RegisterInfo;
+
+        MockGroupClient(const std::string& server_addr = "mock_addr") :
+                m_addr(server_addr) {
+        }
+
+        ~MockGroupClient() override = default;
+
+        const std::string& get_addr() const override {
+            return m_addr;  // fixed fake address; no real server behind it
+        }
+
+        RegisterInfo opr_register(const std::string& key, size_t nr_devices,
+                bool is_root, int rank, uint64_t comp_node_hash) override {
+            return m_mgr.opr_register(key, nr_devices, is_root, rank, comp_node_hash);
+        }
+
+        std::vector<std::string> gather_uid(const std::string& uid,
+                const std::string& key, uint32_t size, uint32_t rank) override {
+            return m_mgr.gather_uid(uid, key, size, rank);
+        }
+
+        void set_output_shape(const std::string& key,
+                const TensorShape& shape) override {
+            m_mgr.set_output_shape(key, shape);
+        }
+
+        TensorShape get_output_shape(const std::string& key) override {
+            return m_mgr.get_output_shape(key);
+        }
+
+        uint32_t group_barrier(uint32_t size, uint32_t rank) override {
+            return m_mgr.group_barrier(size, rank);
+        }
+
+    private:
+        const std::string m_addr;
+        opr::GroupManager m_mgr;  // all calls resolve against this local manager
+};
+
+} // namespace test
+} // namespace mgb
+
+// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}