提交 01092feb 编写于 作者: M Megvii Engine Team 提交者: Xu Xinran

feat(mgb): add PackAllReducePass

GitOrigin-RevId: 59c1b4539340f7bf1b9475255de665b1842accf1
上级 c7e6c658
......@@ -260,3 +260,15 @@ def replace_oprs(dst, oprmap):
repl_dst_vec.push_back(j)
return _mgb._replace_oprs(repl_src_vec, repl_dst_vec, dst_vec)
def set_priority_to_id(dest_vars):
    """Assign each operator's id as its priority across the whole subgraph
    reachable from ``dest_vars``, touching only operators whose priority is
    still the default value of zero.

    :param dest_vars: target vars representing the graph
    """
    var_vec = _mgb._VectorSymbolVar()
    for var in dest_vars:
        assert isinstance(var, _mgb.SymbolVar)
        var_vec.push_back(var)
    _mgb._set_priority_to_id(var_vec)
......@@ -84,6 +84,8 @@ class trace:
:param log_level: Log level.
:param sublinear_memory_config: Configuration for sublinear memory optimization.
If not None, it enables sublinear memory optimization with given setting.
:param allreduce_pack_max_size: Maximum size of an allreduce pack in MB.
If not None, multiple gradients will be packed and synchronized together.
:param profiling: Whether to profile compiled trace. Default: False
"""
......@@ -107,6 +109,7 @@ class trace:
opt_level: int = None,
log_level: int = None,
sublinear_memory_config: SublinearMemoryConfig = None,
allreduce_pack_max_size: int = None,
profiling: bool = False
):
self.__wrapped__ = func
......@@ -114,6 +117,7 @@ class trace:
self._graph_opt_level = opt_level
self._log_level = log_level
self._sublinear_memory_config = sublinear_memory_config
self._allreduce_pack_max_size = allreduce_pack_max_size
self._status = self._UNSTARTED
self._args = None
self._kwargs = None
......@@ -313,6 +317,9 @@ class trace:
"sublinear_mem_cofig.num_worker",
self._sublinear_memory_config.num_worker,
)
# pack allreduce
if self._allreduce_pack_max_size is not None:
cg.set_option("allreduce_pack_max_size", self._allreduce_pack_max_size)
# profile
if self._profiling:
self._profiler = CompGraphProfiler(cg)
......@@ -391,6 +398,7 @@ class trace:
outputs = [outputs]
# _run_wrapped has checked validity of outputs
self._sym_outputs = tuple(i._symvar for i in outputs)
mgb.comp_graph_tools.set_priority_to_id(self._outspec)
self._compiled_func = graph.get_default_graph().compile(None, self._outspec)
def trace(self, *args: Tensor, **kwargs):
......
......@@ -159,7 +159,6 @@ class Optimizer(metaclass=ABCMeta):
:param loss: The obtained loss tensor
"""
rst = []
priority = 0
params = []
for group in self.param_groups:
for param in group["params"]:
......@@ -180,14 +179,14 @@ class Optimizer(metaclass=ABCMeta):
for param, grad in zip(params, grads):
if is_distributed():
priority += 1
with opr_priority_scope(cg, -priority):
# all_reduce_mean
with opr_priority_scope(cg, -(2 ** 30)):
# always run all_reduce_mean first except add_update
grad = (
all_reduce_sum(grad, "grad_" + str(get_group_id()))
/ get_world_size()
)
with opr_priority_scope(cg, (1 << 30) - priority):
with opr_priority_scope(cg, -(2 ** 31)):
# always run add_update first
grad_update = add_update(param.grad, grad)
else:
grad_update = add_update(param.grad, grad)
......
......@@ -66,6 +66,8 @@ bool _config::set_comp_graph_option(
SET_CG_OPTION(graph_opt.jit);
SET_CG_OPTION(graph_opt.tensorrt);
SET_CG_OPTION(graph_opt_level);
SET_CG_OPTION(allreduce_pack_max_size);
SET_CG_OPTION(allreduce_pack_ignore_first);
SET_CG_OPTION(var_sanity_check_first_run);
SET_CG_OPTION(no_profiling_on_shape_change);
SET_CG_OPTION(allocate_static_mem_after_graph_compile);
......
%{
#include "megbrain/gopt/framework.h"
%}
%inline {
SymbolVarArray _get_owner_opr_inputs(SymbolVar var) {
......@@ -35,5 +39,17 @@
}
return mgb::cg::replace_oprs(vars, oprmap);
}
void _set_priority_to_id(const SymbolVarArray& dest_vars) {
auto on_opr = [](mgb::cg::OperatorNodeBase* opr) {
if (opr->node_prop().attribute().priority == 0) {
opr->node_prop().attribute().priority = opr->id();
}
};
mgb::cg::DepOprIter dep_iter{on_opr};
for (const SymbolVar& var : dest_vars) {
dep_iter.add(var);
}
}
}
// vim: ft=swig foldmethod=marker foldmarker=f{{{,f}}}
......@@ -441,12 +441,22 @@ ComputingGraphImpl::CompileState ComputingGraphImpl::compile_prepare(
optimizer.verbosity(options().log_level);
optimizer.enable_check_result(options().graph_opt_level < 0);
if (sopr_stat.has_virtual_grad) {
if (need_opt)
if (need_opt) {
#if MGB_ENABLE_OPR_MM
optimizer.add_pass<gopt::PackAllReduceScanPass>();
#endif
optimizer.add_preset_passes(false, nullptr, &options());
}
optimizer.add_pass<gopt::ExpandVirtualGradPass>();
}
if (need_opt)
if (need_opt) {
optimizer.add_preset_passes(true, nullptr, &options());
#if MGB_ENABLE_OPR_MM
if (sopr_stat.has_virtual_grad) {
optimizer.add_pass<gopt::PackAllReduceReplacePass>();
}
#endif
}
optimizer.apply_inplace(dest_vars);
}
#endif
......
......@@ -327,6 +327,18 @@ class ComputingGraph : public std::enable_shared_from_this<ComputingGraph>,
*/
int16_t graph_opt_level = 2;
/*!
* max size of allreduce packs in MB
* set this option to zero to disable PackAllReducePass
*/
int16_t allreduce_pack_max_size = 0;
/*!
* do not pack the first n allreduces
* PackAllReducePass disabled if allreduce_pack_max_size is zero
*/
int16_t allreduce_pack_ignore_first = 2;
/*!
* set logging level, larger number means more verbose
* 0: no log info
......
......@@ -183,7 +183,6 @@ SymbolVarArray replace_oprs(
SymbolVarArray replace_vars_comp_graph(
const SymbolVarArray &dest, ComputingGraph* new_graph);
SymbolVarArray find_h2d(const SymbolVarArray& dest);
/*!
......
......@@ -17,6 +17,7 @@
#include "megbrain/opr/utility.h"
#include "megbrain/serialization/serializer.h"
#include "megbrain/serialization/opr_shallow_copy.h"
#include "../../core/impl/graph/cg_impl.h"
using namespace mgb;
using namespace gopt;
......@@ -657,4 +658,309 @@ void RemoveRedundantTypeCvtPass::apply(OptState& opt) const {
rewriter.apply_inplace();
}
#if MGB_ENABLE_OPR_MM
#include "megbrain/opr/collective_comm.h"
/* ======================= PackAllReduceScanPass ====================== */
//! human-readable pass identifier used by the optimizer framework
const char* PackAllReduceScanPass::name() const {
    static constexpr const char* kName = "pack_allreduce_scan";
    return kName;
}
//! Mark every allreduce matching the param->grad->allreduce pattern with a
//! pack hash derived from its grad target, so PackAllReduceReplacePass can
//! later group compatible allreduces together.
void PackAllReduceScanPass::apply(OptState& opt) const {
    // packing is globally disabled when the size limit option is zero
    auto* comp_graph = opt.graph().comp_graph();
    if (comp_graph->options().allreduce_pack_max_size == 0) return;

    auto mark_allreduce = [](OperatorNodeBase* opr) {
        if (!check_pattern(opr)) return;
        auto& comm = opr->cast_final_safe<opr::CollectiveComm>();
        // only pack allreduces of grads of the same target, in case two
        // allreduces depend on each other
        VarNode* target = comm.input(0)->owner_opr()->input(0);
        size_t id = target->id();
        comm.set_pack_hash(XXHash().update(&id, sizeof(size_t)).digest());
    };
    opt.graph().iter(mark_allreduce);
}
//! Return true iff \p opr is the tail of the exact chain
//! (Volatile)SharedDeviceTensor -> VirtualGrad -> CollectiveComm(ALL_REDUCE_SUM).
bool PackAllReduceScanPass::check_pattern(OperatorNodeBase* opr) {
    if (!opr->same_type<opr::CollectiveComm>()) return false;
    auto& comm = opr->cast_final_safe<opr::CollectiveComm>();
    if (comm.param().mode != opr::CollectiveComm::Param::Mode::ALL_REDUCE_SUM ||
        comm.input().size() != 1) {
        return false;
    }
    auto* grad = comm.input(0)->owner_opr();
    if (!grad->same_type<opr::VirtualGrad>() || grad->input().size() != 2 ||
        grad->output().size() != 1) {
        return false;
    }
    auto* param = grad->input(1)->owner_opr();
    bool param_is_shared = param->same_type<opr::SharedDeviceTensor>() ||
                           param->same_type<opr::VolatileSharedDeviceTensor>();
    return param_is_shared && param->input().size() == 0;
}
/* ======================= PackAllReduceReplacePass ====================== */
//! human-readable pass identifier used by the optimizer framework
const char* PackAllReduceReplacePass::name() const {
    static constexpr const char* kName = "pack_allreduce_replace";
    return kName;
}
//! Attributes that decide whether two marked allreduces may be merged into
//! one packed allreduce; two oprs belong to the same group iff all fields
//! (plus the scan-pass pack hash) agree. Deduplicated via hash().
class PackAllReduceReplacePass::GroupInfo {
public:
    GroupInfo(int _device, DType _dtype,
            size_t _nr_devices, bool _is_root, int _rank,
            std::shared_ptr<opr::GroupClient> _group_client,
            const std::string& _backend);

    //! fold all fields together with \p extra into a deduplication key
    uint64_t hash(uint64_t extra) const;

    int device;          //!< device index of the input var's comp node
    DType dtype;         //!< dtype of the gradient tensor
    size_t nr_devices;   //!< number of ranks in the collective group
    bool is_root;        //!< whether this rank is the group root
    int rank;            //!< rank of this process within the group
    std::shared_ptr<opr::GroupClient> group_client;  //!< identifies the group server
    std::string backend; //!< collective communication backend name
};
//! member-wise initialization of all grouping attributes; no validation
PackAllReduceReplacePass::GroupInfo::GroupInfo(
        int _device, DType _dtype,
        size_t _nr_devices, bool _is_root, int _rank,
        std::shared_ptr<opr::GroupClient> _group_client,
        const std::string& _backend) :
        device(_device), dtype(_dtype),
        nr_devices(_nr_devices), is_root(_is_root), rank(_rank),
        group_client(_group_client), backend(_backend) {
}
//! Derive the deduplication key for one allreduce by folding \p extra (the
//! pack hash assigned by the scan pass) with every grouping attribute.
//! NOTE: the order and byte widths of the update() calls define the hash
//! value — do not reorder or widen any field.
uint64_t PackAllReduceReplacePass::GroupInfo::hash(uint64_t extra) const {
    DTypeEnum ev = dtype.enumv();
    // the server address stands in for the group_client identity
    const std::string& server_addr = group_client->get_addr();
    return XXHash()
            .update(&extra, sizeof(uint64_t))
            .update(&device, sizeof(int))
            .update(&ev, sizeof(DTypeEnum))
            .update(&nr_devices, sizeof(size_t))
            .update(&is_root, sizeof(bool))
            .update(&rank, sizeof(int))
            .update(server_addr.c_str(), server_addr.size())
            .update(backend.c_str(), backend.size())
            .digest();
}
//! If \p opr is a CollectiveComm marked by PackAllReduceScanPass, record it
//! under its group hash and return that hash; otherwise return 0.
uint64_t PackAllReduceReplacePass::collect_groups(OperatorNodeBase* opr,
        ThinHashMap<uint64_t, std::shared_ptr<GroupInfo>>& group_info,
        ThinHashMap<uint64_t, cg::OprNodeArray>& groups) {
    if (!opr->same_type<opr::CollectiveComm>()) return 0;
    auto& comm = opr->cast_final_safe<opr::CollectiveComm>();
    uint64_t pack_hash = comm.pack_hash();
    if (pack_hash == 0) return 0;  // not marked by the scan pass

    VarNode* inp = comm.input(0);
    auto info = std::make_shared<GroupInfo>(
            inp->comp_node().locator().device,
            inp->dtype(),
            comm.nr_devices(),
            comm.is_root(),
            comm.rank(),
            comm.group_client(),
            comm.backend());

    uint64_t hash = info->hash(pack_hash);
    // emplace is a no-op when the key already exists, so the first
    // GroupInfo recorded for a hash wins
    group_info.emplace(hash, std::move(info));
    groups[hash].push_back(opr);
    return hash;
}
//! Split each group into packs whose accumulated input byte size reaches
//! max_size; packs never span two groups, and packs with a single opr are
//! dropped since packing them gains nothing.
void PackAllReduceReplacePass::divide_packs(
        const ThinHashMap<uint64_t, cg::OprNodeArray>& groups,
        ThinHashMap<uint64_t, std::vector<cg::OprNodeArray>>& packs,
        size_t max_size) {
    cg::OprNodeArray pack;
    size_t sum = 0;
    // iterate by const reference: `for (auto it : groups)` copied every
    // (hash, OprNodeArray) pair on each iteration
    for (const auto& it : groups) {
        uint64_t hash = it.first;
        const cg::OprNodeArray& group = it.second;
        for (size_t i = 0; i < group.size(); i++) {
            OperatorNodeBase* opr = group[i];
            VarNode* var = opr->input(0);
            // skip oprs whose input size cannot be statically inferred
            const TensorShape* shape = var->owner_graph()
                    ->static_infer_manager().infer_shape_fallible(var);
            if (shape == nullptr) continue;
            pack.push_back(opr);
            sum += var->dtype().size(shape->total_nr_elems());
            if (sum >= max_size) {
                if (pack.size() > 1) packs[hash].push_back(pack);
                pack.clear();
                sum = 0;
            }
        }
        // flush the remainder of this group before moving to the next
        if (pack.size() > 1) packs[hash].push_back(pack);
        pack.clear();
        sum = 0;
    }
}
//! Replace the allreduces in \p pack with one packed allreduce: flatten and
//! concat every input, run a single ALL_REDUCE_SUM, then split the result
//! by the recorded element counts and reshape each piece back. Mappings
//! old-output -> new-var are appended to \p replace_map; the caller is
//! responsible for rewriting the graph with them.
void PackAllReduceReplacePass::insert_packed_oprs(
    size_t pack_id,
    const cg::OprNodeArray& pack,
    std::shared_ptr<GroupInfo> info,
    ThinHashMap<VarNode*, VarNode*>& replace_map, int priority) {
    // set priority of every opr inserted below via a graph event receiver;
    // presumably the receiver is deregistered when `handler` goes out of
    // scope at function end — TODO confirm
    mgb_assert(pack.size() > 0);
    auto graph = pack[0]->owner_graph();
    auto on_opr_inserted = [priority] (const cg::event::OprInserted& event) {
        event.opr->node_prop().attribute().priority = priority;
    };
    auto handler = graph->event().register_receiver<cg::event::OprInserted>(on_opr_inserted);
    // flatten inputs and record shapes and partition (element counts)
    std::vector<SymbolVar> shapes;
    SymbolVarArray flattens;
    SymbolVarArray partition;
    for (size_t i = 0; i < pack.size(); i++) {
        VarNode* var = pack[i]->input(0);
        auto shape = opr::GetVarShape::make(SymbolVar(var));
        shapes.push_back(shape);
        SymbolVar flatten = SymbolVar(var).flatten();
        flattens.push_back(flatten);
        // element count = product over the shape vector
        partition.push_back(opr::Reduce::make(shape, {opr::Reduce::Mode::PRODUCT, 0}));
    }
    // concat all flattened gradients into a single 1-d tensor
    SymbolVar concat = opr::Concat::make(flattens, 0);
    // one allreduce over the packed tensor; the key is unique per pack
    std::string key = ssprintf("grad_pack_%zu", pack_id);
    auto param = opr::CollectiveComm::Param::Mode::ALL_REDUCE_SUM;
    SymbolVar allreduce = opr::CollectiveComm::make({concat}, graph,
        key, info->nr_devices, info->is_root, info->rank,
        info->group_client, param, info->dtype, info->backend)[0];
    // split according to recorded partition
    SymbolVarArray splits = opr::Split::make(allreduce,
        opr::Split::Options::make_partition(0, partition));
    // reshape each piece back and record the replacement mapping
    mgb_assert(pack.size() == splits.size());
    for (size_t i = 0; i < pack.size(); i++) {
        VarNode* reshape = splits[i].reshape(shapes[i]).node();
        replace_map[pack[i]->output(0)] = reshape;
    }
}
//! Apply gradient packing: walk the topo order, group marked allreduces,
//! divide them into size-bounded packs, insert packed replacements and
//! rewrite every use of the original allreduce outputs.
void PackAllReduceReplacePass::apply(OptState& opt) const {
    // get graph options; widen BEFORE multiplying: allreduce_pack_max_size
    // is a 16-bit MB count, and e.g. 5000 * 1024 * 1024 overflows a 32-bit
    // int (undefined behavior) if computed in int arithmetic
    auto comp_graph = opt.graph().comp_graph();
    size_t max_size = static_cast<size_t>(
            comp_graph->options().allreduce_pack_max_size) * 1024 * 1024;
    size_t ignore_first = comp_graph->options().allreduce_pack_ignore_first;
    if (max_size == 0) return;

    // get topo order
    auto& topo_sorter = static_cast<cg::ComputingGraphImpl*>(comp_graph)->topo_sorter();
    cg::CompSeqExtraInfo extra_info;
    VarNodeArray endpoints = to_var_node_array(opt.graph().endpoint_vars());
    const cg::OprNodeArray* seq = topo_sorter.get_comp_seq(extra_info, endpoints);
    topo_sorter.restore_opr_prop();

    // collect allreduce groups from the topo sequence, skipping the first
    // `ignore_first` allreduces so early gradients are not delayed
    ThinHashMap<uint64_t, std::shared_ptr<GroupInfo>> group_info;
    ThinHashMap<uint64_t, cg::OprNodeArray> groups;
    for (size_t i = 0; i < seq->size(); i++) {
        if (seq->at(i)->same_type<opr::CollectiveComm>()) {
            if (ignore_first > 0) {
                --ignore_first;
            } else {
                collect_groups(seq->at(i), group_info, groups);
            }
        }
    }

    // divide groups into packs
    ThinHashMap<uint64_t, std::vector<cg::OprNodeArray>> packs;
    divide_packs(groups, packs, max_size);

    // make sure that oprs inserted in this pass (reshape, concat, allreduce,
    // split, reshape) have higher priority than existing operators; cast
    // before negating to avoid unsigned wrap-around on size_t
    int priority = -static_cast<int>(seq->size()) - 100;

    // insert packed operators and generate replace_map; iterate by const
    // reference to avoid copying whole packs
    ThinHashMap<VarNode*, VarNode*> replace_map;
    size_t pack_id = 0;
    for (const auto& it : packs) {
        uint64_t hash = it.first;
        for (const auto& pack : it.second) {
            opt.call_with_opr(pack[0], [&]() {
                insert_packed_oprs(pack_id, pack, group_info[hash], replace_map, priority);
            }, OprPropertyFlag::NONE);
            pack_id += 1;
        }
    }

    // replace vars throughout the graph with their packed counterparts
    auto rewriter = opt.graph().make_rewriter();
    auto cb_replace = [&](OperatorNodeBase* opr) {
        for (auto i : opr->input()) {
            auto iter = replace_map.find(i);
            if (iter != replace_map.end()) {
                rewriter.replace_var(i, iter->second, nullptr);
            }
        }
        rewriter.auto_replace_outputs(opr);
    };
    opt.graph().iter(cb_replace);
    rewriter.apply_inplace();
}
#else
/* ======================= PackAllReduceScanPass ====================== */

// Fallback stubs compiled when MGB_ENABLE_OPR_MM is disabled: the pass
// interface is preserved so callers still link, but no transformation is
// performed.

const char* PackAllReduceScanPass::name() const {
    return "pack_allreduce_scan";
}

// no-op: packing requires the MM opr library
void PackAllReduceScanPass::apply(OptState& opt) const {
}

// trivially true; unreachable in practice since the stub apply() never
// calls it
bool PackAllReduceScanPass::check_pattern(OperatorNodeBase* opr) {
    return true;
}

/* ======================= PackAllReduceReplacePass ====================== */

const char* PackAllReduceReplacePass::name() const {
    return "pack_allreduce_replace";
}

// no-op: packing requires the MM opr library
void PackAllReduceReplacePass::apply(OptState& opt) const {}

// stub: reports "no group" (0) for every operator
uint64_t PackAllReduceReplacePass::collect_groups(
        OperatorNodeBase* opr,
        ThinHashMap<uint64_t, std::shared_ptr<GroupInfo>>& group_info,
        ThinHashMap<uint64_t, cg::OprNodeArray>& groups) {
    return 0;
}

// stub: leaves `packs` untouched
void PackAllReduceReplacePass::divide_packs(
        const ThinHashMap<uint64_t, cg::OprNodeArray>& groups,
        ThinHashMap<uint64_t, std::vector<cg::OprNodeArray>>& packs,
        size_t max_size) {
}

// stub: leaves `replace_map` untouched
void PackAllReduceReplacePass::insert_packed_oprs(
        size_t pack_id,
        const cg::OprNodeArray& pack,
        std::shared_ptr<GroupInfo> info,
        ThinHashMap<VarNode*, VarNode*>& replace_map, int priority) {
}
#endif // MGB_ENABLE_OPR_MM
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
......@@ -11,6 +11,8 @@
#pragma once
#include <vector>
#include "megbrain/gopt/framework.h"
namespace mgb {
......@@ -90,6 +92,45 @@ namespace gopt {
void apply(OptState& opt) const override;
};
//! scan allreduces of param grads
class PackAllReduceScanPass final : public Pass {
public:
const char* name() const override;
void apply(OptState& opt) const override;
private:
// check pattern param -> grad -> allreduce
static bool check_pattern(OperatorNodeBase* opr);
};
//! pack allreduces of param grads
class PackAllReduceReplacePass final : public Pass {
public:
class GroupInfo;
const char* name() const override;
void apply(OptState& opt) const override;
// collect allreduces and divide into groups
static uint64_t collect_groups(
OperatorNodeBase* opr,
ThinHashMap<uint64_t, std::shared_ptr<GroupInfo>>& group_info,
ThinHashMap<uint64_t, cg::OprNodeArray>& groups);
// divide groups into packs, max_size in MB
static void divide_packs(
const ThinHashMap<uint64_t, cg::OprNodeArray>& groups,
ThinHashMap<uint64_t, std::vector<cg::OprNodeArray>>& packs,
size_t max_size);
// insert packed operators and update replace_map
static void insert_packed_oprs(
size_t pack_id,
const cg::OprNodeArray& pack,
std::shared_ptr<GroupInfo> info,
ThinHashMap<VarNode*, VarNode*>& replace_map, int priority);
};
} // namespace gopt
} // namespace mgb
......
......@@ -14,6 +14,7 @@
#include "megbrain/gopt/basic_arith.h"
#include "megbrain/gopt/misc.h"
#include "megbrain/opr/basic_arith_wrapper.h"
#include "megbrain/opr/blas.h"
#include "megbrain/opr/cond.h"
#include "megbrain/opr/tensor_manip.h"
#include "megbrain/opr/utility.h"
......@@ -410,4 +411,322 @@ TEST_PASS(RemoveRedundantTypeCvtPass, Basic) {
check(x_q8_q8, x_q8_fp32_q8_);
}
#if MGB_ENABLE_OPR_MM
#include "megbrain/opr/collective_comm.h"
#include "../../opr-mm/test/mock_client.h"
// The scan pass must give equal pack hashes to allreduces whose gradients
// share the same target (y0 vs y1) and different hashes across targets,
// regardless of whether the wrt-param is a SharedDeviceTensor or a
// VolatileSharedDeviceTensor.
TEST_PASS(PackAllReduceScanPass, Basic) {
    auto graph = ComputingGraph::make();
    // non-zero limit enables the pass
    graph->options().allreduce_pack_max_size = 5000;
    auto client = std::make_shared<test::MockGroupClient>();
    auto cn = CompNode::load("gpux");
    auto dev_x0 = std::make_shared<DeviceTensorND>(cn, TensorShape{3, 5});
    auto dev_x1 = std::make_shared<DeviceTensorND>(cn, TensorShape{4, 6});
    auto dev_y0 = std::make_shared<DeviceTensorND>(cn, TensorShape{1});
    auto dev_y1 = std::make_shared<DeviceTensorND>(cn, TensorShape{1});
    auto x0 = opr::SharedDeviceTensor::make(*graph, dev_x0);
    auto x1 = opr::VolatileSharedDeviceTensor::make(*graph, dev_x1);
    auto y0 = opr::SharedDeviceTensor::make(*graph, dev_y0);
    auto y1 = opr::VolatileSharedDeviceTensor::make(*graph, dev_y1);
    // four param -> grad -> allreduce chains matching the scan pattern
    auto grad0 = opr::VirtualGrad::make(y0, x0);
    auto grad1 = opr::VirtualGrad::make(y0, x1);
    auto grad2 = opr::VirtualGrad::make(y1, x0);
    auto grad3 = opr::VirtualGrad::make(y1, x1);
    auto mode = opr::CollectiveComm::Param::Mode::ALL_REDUCE_SUM;
    auto comm0 = opr::CollectiveComm::make({grad0}, graph.get(),
            "grad0", 2, 0, 0, client, mode)[0];
    auto comm1 = opr::CollectiveComm::make({grad1}, graph.get(),
            "grad1", 2, 0, 0, client, mode)[0];
    auto comm2 = opr::CollectiveComm::make({grad2}, graph.get(),
            "grad2", 2, 0, 0, client, mode)[0];
    auto comm3 = opr::CollectiveComm::make({grad3}, graph.get(),
            "grad3", 2, 0, 0, client, mode)[0];
    gopt::GraphOptimizer()
        .add_pass<gopt::PackAllReduceScanPass>()
        .apply({{comm0, comm1, comm2, comm3}});
    auto get_hash = [] (const SymbolVar& symvar) {
        cg::OperatorNodeBase* opr = symvar.node()->owner_opr();
        return opr->cast_final_safe<opr::CollectiveComm>().pack_hash();
    };
    uint64_t hash0 = get_hash(comm0);
    uint64_t hash1 = get_hash(comm1);
    uint64_t hash2 = get_hash(comm2);
    uint64_t hash3 = get_hash(comm3);
    // same target => same hash; different target => different hash
    ASSERT_EQ(hash0, hash1);
    ASSERT_EQ(hash2, hash3);
    ASSERT_NE(hash0, hash2);
}
// collect_groups must separate allreduces by comp node, dtype, group client
// and extra (pack) hash, while merging those that agree on all attributes
// even when tensor shapes differ.
TEST_PASS(PackAllReduceReplacePass, CollectGroups) {
    REQUIRE_GPU(2);
    auto cns = load_multiple_xpus(2);
    auto graph = ComputingGraph::make();
    graph->options().graph_opt_level = 2;
    auto cli0 = std::make_shared<test::MockGroupClient>("mock_addr0");
    auto cli1 = std::make_shared<test::MockGroupClient>("mock_addr1");
    using GroupInfo = gopt::PackAllReduceReplacePass::GroupInfo;
    ThinHashMap<uint64_t, std::shared_ptr<GroupInfo>> group_info;
    ThinHashMap<uint64_t, cg::OprNodeArray> groups;
    // build one param -> grad -> allreduce chain, mark it with extra_hash
    // and feed it to collect_groups; returns the group hash
    auto add_opr = [&] (const CompNode& cn, TensorShape shape, const DType& dt,
            std::shared_ptr<test::MockGroupClient> client, uint64_t extra_hash) {
        auto dev0 = std::make_shared<DeviceTensorND>(cn, shape, dt);
        auto wrt = opr::SharedDeviceTensor::make(*graph, dev0);
        auto dev1 = std::make_shared<DeviceTensorND>(cn, TensorShape{1}, dt);
        auto target = opr::SharedDeviceTensor::make(*graph, dev1);
        auto grad = opr::VirtualGrad::make(target, wrt);
        auto comm = opr::CollectiveComm::make(
                {grad}, graph.get(), "key", 2, 0, 0, client,
                opr::CollectiveComm::Param::Mode::ALL_REDUCE_SUM)[0]
                .node()->owner_opr();
        comm->cast_final_safe<opr::CollectiveComm>().set_pack_hash(extra_hash);
        return gopt::PackAllReduceReplacePass::collect_groups(comm, group_info, groups);
    };
    // each of hash2..hash5 varies exactly one grouping attribute
    uint64_t hash0 = add_opr(cns[0], TensorShape{1, 3}, dtype::Float32{}, cli0, 1);
    uint64_t hash1 = add_opr(cns[0], TensorShape{2, 4}, dtype::Float32{}, cli0, 1); // same
    uint64_t hash2 = add_opr(cns[1], TensorShape{3, 5}, dtype::Float32{}, cli0, 1); // comp_node
    uint64_t hash3 = add_opr(cns[0], TensorShape{4, 6}, dtype::Float16{}, cli0, 1); // dtype
    uint64_t hash4 = add_opr(cns[0], TensorShape{5, 7}, dtype::Float32{}, cli1, 1); // client
    uint64_t hash5 = add_opr(cns[0], TensorShape{6, 8}, dtype::Float32{}, cli0, 2); // extra_hash
    ASSERT_EQ(hash0, hash1);
    // 6 calls, 5 distinct hashes (hash0 == hash1)
    std::set<uint64_t> s;
    s.insert(hash0);
    s.insert(hash1);
    s.insert(hash2);
    s.insert(hash3);
    s.insert(hash4);
    s.insert(hash5);
    ASSERT_EQ(5, s.size());
    ASSERT_EQ(1, group_info.count(hash0));
    ASSERT_EQ(1, group_info.count(hash1));
    ASSERT_EQ(1, group_info.count(hash2));
    ASSERT_EQ(1, group_info.count(hash3));
    ASSERT_EQ(1, group_info.count(hash4));
    ASSERT_EQ(1, group_info.count(hash5));
    // only the merged group holds two oprs
    ASSERT_EQ(2, groups[hash0].size());
    ASSERT_EQ(2, groups[hash1].size());
    ASSERT_EQ(1, groups[hash2].size());
    ASSERT_EQ(1, groups[hash3].size());
    ASSERT_EQ(1, groups[hash4].size());
    ASSERT_EQ(1, groups[hash5].size());
}
// divide_packs must cut each group into packs once the accumulated byte
// size reaches max_size (here 1000), never mixing groups, and keep a
// trailing pack only if it contains more than one opr.
TEST_PASS(PackAllReduceReplacePass, DividePacks) {
    auto cn = CompNode::load("gpux");
    auto graph = ComputingGraph::make();
    auto client = std::make_shared<test::MockGroupClient>();
    auto mode = opr::CollectiveComm::Param::Mode::ALL_REDUCE_SUM;
    ThinHashMap<uint64_t, cg::OprNodeArray> groups;
    ThinHashMap<uint64_t, std::vector<cg::OprNodeArray>> packs;
    // build a marked allreduce whose input occupies `size` bytes
    auto insert_opr = [&] (size_t size) {
        auto dev = std::make_shared<DeviceTensorND>(cn, TensorShape{size / sizeof(float)});
        auto sd = opr::SharedDeviceTensor::make(*graph, dev);
        auto symvar = opr::CollectiveComm::make({sd}, graph.get(),
                "key", 2, 0, 0, client, mode)[0];
        auto opr = symvar.node()->owner_opr();
        auto& comm = opr->cast_final_safe<opr::CollectiveComm>();
        comm.set_pack_hash(1);
        return opr;
    };
    // total input bytes of one pack
    auto pack_size = [&] (cg::OprNodeArray& pack) {
        size_t sum = 0;
        for (size_t i = 0; i < pack.size(); i++) {
            auto var = pack[i]->input(0);
            sum += var->dtype().size(var->shape().total_nr_elems());
        }
        return sum;
    };
    groups[0].push_back(insert_opr(100)); // group0, pack0, size=1100
    groups[0].push_back(insert_opr(300)); // group0, pack0, size=1100
    groups[0].push_back(insert_opr(400)); // group0, pack0, size=1100
    groups[0].push_back(insert_opr(300)); // group0, pack0, size=1100
    groups[0].push_back(insert_opr(500)); // group0, pack1, size=800
    groups[0].push_back(insert_opr(200)); // group0, pack1, size=800
    groups[0].push_back(insert_opr(100)); // group0, pack1, size=800
    groups[1].push_back(insert_opr(100)); // group1, pack0, size=900
    groups[1].push_back(insert_opr(400)); // group1, pack0, size=900
    groups[1].push_back(insert_opr(300)); // group1, pack0, size=900
    groups[1].push_back(insert_opr(100)); // group1, pack0, size=900
    gopt::PackAllReduceReplacePass::divide_packs(groups, packs, 1000);
    ASSERT_EQ(2, packs.size());
    // group0: first four oprs cross the 1000-byte limit (1100), the
    // remaining three (800) are flushed as a trailing pack
    ASSERT_EQ(2, packs[0].size());
    ASSERT_EQ(4, packs[0][0].size());
    ASSERT_EQ(1100, pack_size(packs[0][0]));
    ASSERT_EQ(3, packs[0][1].size());
    ASSERT_EQ(800, pack_size(packs[0][1]));
    // group1: all four oprs (900 bytes) stay below the limit until flushed
    ASSERT_EQ(1, packs[1].size());
    ASSERT_EQ(4, packs[1][0].size());
    ASSERT_EQ(900, pack_size(packs[1][0]));
}
// insert_packed_oprs must produce, for a pack of two allreduces, the graph
// flatten+concat -> single allreduce -> split -> reshape, and fill
// replace_map so each old allreduce output maps to its reshaped slice.
TEST_PASS(PackAllReduceReplacePass, InsertPackedOprs) {
    auto cn = CompNode::load("gpux");
    auto graph = ComputingGraph::make();
    auto client = std::make_shared<test::MockGroupClient>();
    auto mode = opr::CollectiveComm::Param::Mode::ALL_REDUCE_SUM;
    size_t nr_devices = 2;
    uint32_t rank = 0;
    uint32_t root = 0;
    using GroupInfo = gopt::PackAllReduceReplacePass::GroupInfo;
    ThinHashMap<uint64_t, std::shared_ptr<GroupInfo>> group_info;
    ThinHashMap<uint64_t, cg::OprNodeArray> groups;
    // build a marked allreduce and register it with collect_groups
    auto insert_opr = [&] (const TensorShape& shape) {
        auto dev = std::make_shared<DeviceTensorND>(cn, shape);
        auto sd = opr::SharedDeviceTensor::make(*graph, dev);
        auto symvar = opr::CollectiveComm::make({sd}, graph.get(),
                "key", nr_devices, rank, root, client, mode)[0];
        auto opr = symvar.node()->owner_opr();
        auto& comm = opr->cast_final_safe<opr::CollectiveComm>();
        comm.set_pack_hash(1);
        gopt::PackAllReduceReplacePass::collect_groups(opr, group_info, groups);
        return symvar;
    };
    auto shape_x = TensorShape{100, 200};
    auto shape_y = TensorShape{200, 400};
    auto x = insert_opr(shape_x);
    auto y = insert_opr(shape_y);
    // both allreduces land in one group
    ASSERT_EQ(1, group_info.size());
    ASSERT_EQ(1, groups.size());
    auto info = group_info.begin()->second;
    auto pack = groups.begin()->second;
    size_t pack_id = 0;
    ThinHashMap<VarNode*, VarNode*> replace_map;
    gopt::PackAllReduceReplacePass::insert_packed_oprs(pack_id, pack, info, replace_map, -1);
    // rebuild the expected packed subgraph by hand; deduplication in the
    // graph makes the hand-built vars identical to the inserted ones
    auto grad_x = SymbolVar(x.node()->owner_opr()->input(0));
    auto grad_y = SymbolVar(y.node()->owner_opr()->input(0));
    auto concat = opr::Concat::make({grad_x.flatten(), grad_y.flatten()}, 0);
    std::string key = ssprintf("grad_pack_%zu", pack_id);
    auto allreduce = opr::CollectiveComm::make({concat}, graph.get(),
            key, nr_devices, rank, root, client, mode)[0];
    std::vector<size_t> partition;
    partition.push_back(shape_x.total_nr_elems());
    partition.push_back(shape_y.total_nr_elems());
    auto splits = opr::Split::make(allreduce,
            opr::Split::Options::make_partition(allreduce, 0, partition));
    ASSERT_EQ(2, splits.size());
    auto dest_x = splits[0].reshape(shape_x);
    auto dest_y = splits[1].reshape(shape_y);
    ASSERT_EQ(2, replace_map.size());
    ASSERT_TRUE(replace_map.count(x.node()) > 0);
    ASSERT_EQ(replace_map.at(x.node()), dest_x.node());
    ASSERT_TRUE(replace_map.count(y.node()) > 0);
    ASSERT_EQ(replace_map.at(y.node()), dest_y.node());
}
// End-to-end check on two GPUs: running the packed allreduce graph must
// produce numerically identical results to the original unpacked graph.
TEST_PASS(PackAllReduceReplacePass, Equivalence) {
    REQUIRE_GPU(2);
    auto cns = load_multiple_xpus(2);
    auto client = std::make_shared<test::MockGroupClient>();
    // build loss = x @ y, allreduce both grads, then run both passes;
    // `array` receives [orig_x, orig_y, packed_x, packed_y]
    auto build_graph = [&] (uint32_t rank, std::shared_ptr<ComputingGraph> graph,
            SymbolVarArray& array) {
        HostTensorGenerator<> gen;
        auto cn = cns[rank];
        auto host_x = gen({1, 1000});
        auto host_y = gen({1000, 1});
        auto dev_x = std::make_shared<DeviceTensorND>(cn);
        auto dev_y = std::make_shared<DeviceTensorND>(cn);
        dev_x->copy_from(*host_x).sync();
        dev_y->copy_from(*host_y).sync();
        auto x = opr::SharedDeviceTensor::make(*graph, dev_x);
        auto y = opr::VolatileSharedDeviceTensor::make(*graph, dev_y);
        auto loss = opr::MatrixMul::make(x, y).flatten();
        auto grad_x = opr::VirtualGrad::make(loss, x);
        auto grad_y = opr::VirtualGrad::make(loss, y);
        using Mode = opr::CollectiveComm::Param::Mode;
        bool is_root = (rank == 0);
        auto reduced_x = opr::CollectiveComm::make({grad_x}, graph.get(),
                "x", 2, is_root, rank, client, Mode::ALL_REDUCE_SUM)[0] / 2;
        auto reduced_y = opr::CollectiveComm::make({grad_y}, graph.get(),
                "y", 2, is_root, rank, client, Mode::ALL_REDUCE_SUM)[0] / 2;
        // pack everything: no allreduce is skipped
        graph->options().allreduce_pack_max_size = 5000;
        graph->options().allreduce_pack_ignore_first = 0;
        auto dest_vars = gopt::GraphOptimizer{}
                .add_pass<gopt::PackAllReduceScanPass>()
                .add_pass<gopt::PackAllReduceReplacePass>()
                .apply({{reduced_x, reduced_y}}).endpoint_vars();
        array.emplace_back(reduced_x);
        array.emplace_back(reduced_y);
        array.emplace_back(dest_vars[0]);
        array.emplace_back(dest_vars[1]);
    };
    // one rank per thread; each compares packed vs unpacked outputs
    auto run = [&] (uint32_t rank) {
        auto graph = ComputingGraph::make();
        SymbolVarArray array;
        build_graph(rank, graph, array);
        HostTensorND host_reduced_x, host_reduced_y, host_dest_0, host_dest_1;
        // disable packing at compile time so the original allreduces are
        // not transformed again
        graph->options().allreduce_pack_max_size = 0;
        auto func = graph->compile({make_callback_copy(array[0], host_reduced_x),
                                    make_callback_copy(array[1], host_reduced_y),
                                    make_callback_copy(array[2], host_dest_0),
                                    make_callback_copy(array[3], host_dest_1)});
        func->execute();
        MGB_ASSERT_TENSOR_EQ(host_reduced_x, host_dest_0);
        MGB_ASSERT_TENSOR_EQ(host_reduced_y, host_dest_1);
    };
    std::thread t0(run, 0);
    std::thread t1(run, 1);
    t0.join();
    t1.join();
}
#endif // MGB_ENABLE_OPR_MM
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
......@@ -461,16 +461,7 @@ void CollectiveComm::opr_register() {
m_rank = reg_info.rank;
m_root = reg_info.root_rank;
MegRayCommunicatorBuilder* builder;
{
static std::mutex user_data_mtx;
std::unique_lock<std::mutex> lk(user_data_mtx);
builder = owner_graph()->options().user_data
.get_user_data_or_create<MegRayCommunicatorBuilder>();
}
m_megray_comm = builder->get_megray_comm(
m_megray_comm = MegRayCommBuilder::get_megray_comm(
reg_info.hash, m_key, m_nr_devices, m_rank,
get_megray_backend(m_backend), m_group_client);
......@@ -736,13 +727,15 @@ cg::OperatorNodeBase* opr_shallow_copy_collective_mm(
const cg::OperatorNodeBase& opr_, const VarNodeArray& inputs,
const OperatorNodeConfig& config) {
auto&& opr = opr_.cast_final_safe<opr::CollectiveComm>();
return opr::CollectiveComm::make(
auto new_opr = CollectiveComm::make(
to_symbol_var_array(inputs), ctx.owner_graph(opr_, inputs),
opr.key(), opr.nr_devices(), opr.is_root(), opr.rank(),
opr.group_client(), opr.dev_buffers(), opr.param(),
opr.dtype(), opr.backend(), config)[0]
.node()
->owner_opr();
new_opr->cast_final_safe<opr::CollectiveComm>().set_pack_hash(opr.pack_hash());
return new_opr;
}
MGB_REG_OPR_SHALLOW_COPY(CollectiveComm, opr_shallow_copy_collective_mm);
......
......@@ -54,13 +54,7 @@ void RemoteSend::scn_do_execute() {
auto reg_info = m_group_client->opr_register(m_peer.key, 2, 0, false,
comp_node.get_uid());
auto megray_comm_builder =
owner_graph()
->options()
.user_data
.get_user_data_or_create<MegRayCommunicatorBuilder>();
m_megray_comm = megray_comm_builder->get_megray_comm(
m_megray_comm = MegRayCommBuilder::get_megray_comm(
reg_info.hash, m_peer.key, 2, 0, MegRay::MEGRAY_UCX, m_group_client);
m_init = true;
}
......@@ -158,13 +152,7 @@ void RemoteRecv::scn_do_execute() {
m_peer.key, 2, false, 1,
comp_node.get_uid());
auto megray_comm_builder =
owner_graph()
->options()
.user_data
.get_user_data_or_create<MegRayCommunicatorBuilder>();
m_megray_comm = megray_comm_builder->get_megray_comm(
m_megray_comm = MegRayCommBuilder::get_megray_comm(
reg_info.hash, m_peer.key, 2, 1, MegRay::MEGRAY_UCX, m_group_client);
m_init = true;
}
......
......@@ -14,8 +14,8 @@
using namespace mgb;
using namespace opr;
bool MegRayCommunicatorBuilder::find(uint64_t hash, std::shared_ptr<MegRay::Communicator>& comm) {
std::unique_lock<std::mutex> lk(m_mtx);
bool MegRayCommBuilder::find(uint64_t hash, std::shared_ptr<MegRay::Communicator>& comm) {
std::unique_lock<std::mutex> lk(m_map_mtx);
auto it = m_megray_comms.find(hash);
if (it != m_megray_comms.end()) {
comm = it->second;
......@@ -24,27 +24,37 @@ bool MegRayCommunicatorBuilder::find(uint64_t hash, std::shared_ptr<MegRay::Comm
return false;
}
void MegRayCommunicatorBuilder::emplace(uint64_t hash,
void MegRayCommBuilder::emplace(uint64_t hash,
std::shared_ptr<MegRay::Communicator> comm) {
std::unique_lock<std::mutex> lk(m_mtx);
std::unique_lock<std::mutex> lk(m_map_mtx);
m_megray_comms.emplace(hash, comm);
}
std::shared_ptr<MegRay::Communicator> MegRayCommunicatorBuilder::get_megray_comm(
std::shared_ptr<MegRay::Communicator> MegRayCommBuilder::get_megray_comm(
uint64_t hash, std::string key, uint32_t size, uint32_t rank,
MegRay::Backend backend,
std::shared_ptr<mgb::opr::GroupClient> group_client) {
{
// singleton pattern
std::unique_lock<std::mutex> lk(sm_instance_mtx);
if (sm_instance == nullptr) {
sm_instance = new MegRayCommBuilder();
}
}
std::shared_ptr<MegRay::Communicator> comm;
if (!find(hash, comm)) {
if (!sm_instance->find(hash, comm)) {
comm = MegRay::get_communicator(size, rank, backend);
auto uid = comm->get_uid();
auto uids = group_client->gather_uid(uid, key, size, rank);
mgb_assert(comm->init(uids) == MegRay::Status::MEGRAY_OK);
emplace(hash, comm);
sm_instance->emplace(hash, comm);
}
return comm;
}
MGB_TYPEINFO_OBJ_IMPL(MegRayCommunicatorBuilder);
MegRayCommBuilder* MegRayCommBuilder::sm_instance = nullptr;
std::mutex MegRayCommBuilder::sm_instance_mtx;
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
......@@ -81,6 +81,10 @@ public:
return m_group_client;
}
void set_pack_hash(uint64_t hash) { m_pack_hash = hash; }
uint64_t pack_hash() const { return m_pack_hash; }
std::shared_ptr<MegRay::Context> megray_ctx() const {
return m_megray_ctx;
}
......@@ -123,6 +127,9 @@ private:
// whose shape infer should be disabled *during* static infer phase.
bool m_enable_shape_infer = false;
//! set in PackAllReduceScanPass and used in PackAllReduceReplacePass
uint64_t m_pack_hash = 0;
std::shared_ptr<MegRay::Context> m_megray_ctx;
std::shared_ptr<MegRay::Communicator> m_megray_comm;
bool m_init = false;
......
......@@ -126,6 +126,8 @@ class GroupClient {
virtual ~GroupClient() = default;
public:
virtual const std::string& get_addr() const = 0;
virtual GroupManager::RegisterInfo opr_register(const std::string& key,
size_t nr_devices,
bool is_root, int rank,
......
......@@ -23,18 +23,19 @@ namespace opr {
/*!
* gather MegRay unique ids and build communicator, use hash for deduplication
*/
class MegRayCommunicatorBuilder final : public mgb::UserDataContainer::UserData {
MGB_TYPEINFO_OBJ_DECL;
class MegRayCommBuilder {
private:
bool find(uint64_t hash, std::shared_ptr<MegRay::Communicator>& comm);
void emplace(uint64_t hash, std::shared_ptr<MegRay::Communicator> comm);
std::unordered_map<uint64_t, std::shared_ptr<MegRay::Communicator>> m_megray_comms;
std::mutex m_mtx;
std::mutex m_map_mtx;
static MegRayCommBuilder* sm_instance;
static std::mutex sm_instance_mtx;
public:
std::shared_ptr<MegRay::Communicator> get_megray_comm(
static std::shared_ptr<MegRay::Communicator> get_megray_comm(
uint64_t hash, std::string key, uint32_t size, uint32_t rank,
MegRay::Backend backend,
std::shared_ptr<mgb::opr::GroupClient> group_client);
......
......@@ -47,7 +47,7 @@ public:
uint32_t group_barrier(uint32_t size, uint32_t rank) override;
const std::string& get_addr() const {
const std::string& get_addr() const override {
return m_addr;
}
......
......@@ -17,11 +17,10 @@
#include "megbrain/opr/utility.h"
#include "megbrain/test/helper.h"
#include "megbrain/graph.h"
#include "mock_client.h"
using namespace mgb;
namespace {
using Mode = opr::CollectiveComm::Param::Mode;
SymbolVar make_all_reduce_output(const Mode mode,
......@@ -41,41 +40,6 @@ SymbolVarArray make_reduce_scatter_sum_output(const SymbolVarArray& inputs) {
rdc, opr::Split::Options::make_average(0, inputs.size()));
}
class MockGroupClient final : public opr::GroupClient {
public:
~MockGroupClient() override = default;
opr::GroupManager::RegisterInfo opr_register(const std::string& key,
size_t nr_devices,
bool is_root, int rank,
uintptr_t stream) {
return m_mgr.opr_register(key, nr_devices, is_root, rank, stream);
}
std::vector<std::string> gather_uid(const std::string& uid,
const std::string& key, uint32_t size, uint32_t rank) {
return m_mgr.gather_uid(uid, key, size, rank);
}
void set_output_shape(const std::string& key,
const TensorShape& shape) override {
m_mgr.set_output_shape(key, shape);
}
TensorShape get_output_shape(const std::string& key) override {
return m_mgr.get_output_shape(key);
}
uint32_t group_barrier(uint32_t size, uint32_t rank) override {
return m_mgr.group_barrier(size, rank);
}
private:
opr::GroupManager m_mgr;
};
} // namespace
TEST(TestOprCollectiveComm, AllReduce) {
REQUIRE_GPU(2);
......@@ -88,7 +52,7 @@ TEST(TestOprCollectiveComm, AllReduce) {
auto host_x1 = gen({28, 28});
HostTensorND host_y0, host_y1, host_y_expect;
auto client = std::make_shared<MockGroupClient>();
auto client = std::make_shared<test::MockGroupClient>();
auto graph = ComputingGraph::make();
auto x0 = opr::Host2DeviceCopy::make(*graph, host_x0, cn0);
......@@ -126,7 +90,7 @@ TEST(TestOprCollectiveComm, AllReduceMultiThread) {
auto host_x1 = gen({28, 28});
HostTensorND host_y0, host_y1, host_y_expect;
auto client = std::make_shared<MockGroupClient>();
auto client = std::make_shared<test::MockGroupClient>();
auto run_0 = [&]() {
auto graph0 = ComputingGraph::make();
......@@ -187,7 +151,7 @@ TEST(TestOprCollectiveComm, AllReduceWithGrad) {
HostTensorND host_y0, host_y1, host_y_expect;
HostTensorND host_out_grad0, host_out_grad1, host_out_grad_expect;
auto client = std::make_shared<MockGroupClient>();
auto client = std::make_shared<test::MockGroupClient>();
auto run_0 = [&]() { // rank 0
auto graph0 = ComputingGraph::make();
......@@ -268,7 +232,7 @@ TEST(TestOprCollectiveComm, AllGather) {
auto host_x1 = gen({28, 28});
HostTensorND host_y0, host_y1, host_y_expect;
auto client = std::make_shared<MockGroupClient>();
auto client = std::make_shared<test::MockGroupClient>();
auto graph = ComputingGraph::make();
auto x0 = opr::Host2DeviceCopy::make(*graph, host_x0, cn0);
......@@ -300,7 +264,7 @@ TEST(TestOprCollectiveComm, AllGatherMultiThread) {
auto host_x1 = gen({28, 28});
HostTensorND host_y0, host_y1, host_y_expect;
auto client = std::make_shared<MockGroupClient>();
auto client = std::make_shared<test::MockGroupClient>();
auto run_0 = [&]() { // rank 0
auto graph0 = ComputingGraph::make();
......@@ -356,7 +320,7 @@ TEST(TestOprCollectiveComm, AllGatherWithGrad) {
HostTensorND host_out_grad0, host_out_grad1;
HostTensorND host_out_grad0_expect, host_out_grad1_expect;
auto client = std::make_shared<MockGroupClient>();
auto client = std::make_shared<test::MockGroupClient>();
auto run_0 = [&]() { // rank 0
auto graph0 = ComputingGraph::make();
......@@ -438,7 +402,7 @@ TEST(TestOprCollectiveComm, ReduceScatterSum) {
auto host_x1 = gen({28, 28});
HostTensorND host_y0, host_y1, host_y0_expect, host_y1_expect;
auto client = std::make_shared<MockGroupClient>();
auto client = std::make_shared<test::MockGroupClient>();
auto graph = ComputingGraph::make();
auto x0 = opr::Host2DeviceCopy::make(*graph, host_x0, cn0);
......@@ -471,7 +435,7 @@ TEST(TestOprCollectiveComm, ReduceScatterSumMultiThread) {
auto host_x1 = gen({8});
HostTensorND host_y0, host_y1, host_y0_expect, host_y1_expect;
auto client = std::make_shared<MockGroupClient>();
auto client = std::make_shared<test::MockGroupClient>();
auto run_0 = [&]() { // rank 0
auto graph0 = ComputingGraph::make();
......@@ -528,7 +492,7 @@ TEST(TestOprCollectiveComm, ReduceScatterSumWithGrad) {
HostTensorND host_y0, host_y1, host_y0_expect, host_y1_expect;
HostTensorND host_out_grad0, host_out_grad1, host_out_grad_expect;
auto client = std::make_shared<MockGroupClient>();
auto client = std::make_shared<test::MockGroupClient>();
auto run_0 = [&]() { // rank 0
auto graph0 = ComputingGraph::make();
......@@ -610,7 +574,7 @@ TEST(TestOprCollectiveComm, ReduceSum) {
auto host_x1 = gen({28, 28});
HostTensorND host_y0, host_y1, host_y_expect;
auto client = std::make_shared<MockGroupClient>();
auto client = std::make_shared<test::MockGroupClient>();
auto graph = ComputingGraph::make();
auto x0 = opr::Host2DeviceCopy::make(*graph, host_x0, cn0);
......@@ -641,7 +605,7 @@ TEST(TestOprCollectiveComm, ReduceSumMultiThread) {
auto host_x1 = gen({28, 28});
HostTensorND host_y0, host_y_expect;
auto client = std::make_shared<MockGroupClient>();
auto client = std::make_shared<test::MockGroupClient>();
auto run_0 = [&]() { // rank 0
auto graph0 = ComputingGraph::make();
......@@ -694,7 +658,7 @@ TEST(TestOprCollectiveComm, ReduceSumWithGrad) {
HostTensorND host_y0, host_y0_expect, host_out_grad0, host_out_grad1;
auto client = std::make_shared<MockGroupClient>();
auto client = std::make_shared<test::MockGroupClient>();
auto run_0 = [&]() { // rank 0
auto graph0 = ComputingGraph::make();
......@@ -764,7 +728,7 @@ TEST(TestOprCollectiveComm, Broadcast) {
auto host_x0 = gen({28, 28});
HostTensorND host_y0, host_y1, host_y_expect;
auto client = std::make_shared<MockGroupClient>();
auto client = std::make_shared<test::MockGroupClient>();
auto graph = ComputingGraph::make();
auto x0 = opr::Host2DeviceCopy::make(*graph, host_x0, cn0);
......@@ -794,7 +758,7 @@ TEST(TestOprCollectiveComm, BroadcastMultiThread) {
auto host_x0 = gen({28, 28});
HostTensorND host_y0, host_y1;
auto client = std::make_shared<MockGroupClient>();
auto client = std::make_shared<test::MockGroupClient>();
auto run_0 = [&]() { // rank 0
auto graph0 = ComputingGraph::make();
......@@ -840,7 +804,7 @@ TEST(TestOprCollectiveComm, BroadcastWithGrad) {
HostTensorND host_y0, host_y1, host_out_grad, host_out_grad_expect;
auto client = std::make_shared<MockGroupClient>();
auto client = std::make_shared<test::MockGroupClient>();
auto run_0 = [&]() { // rank 0
auto graph0 = ComputingGraph::make();
......
......@@ -14,51 +14,14 @@
#include "megbrain/opr/utility.h"
#include "megbrain/system.h"
#include "megbrain/test/helper.h"
#include "mock_client.h"
#include <thread>
using namespace mgb;
using namespace opr;
namespace {
class MockGroupClient final : public opr::GroupClient {
public:
~MockGroupClient() override = default;
opr::GroupManager::RegisterInfo opr_register(const std::string& key,
size_t nr_devices,
bool is_root, int rank,
uint64_t comp_node_hash) {
return m_mgr.opr_register(key, nr_devices, is_root, rank, comp_node_hash);
}
std::vector<std::string> gather_uid(const std::string& uid,
const std::string& key, uint32_t size, uint32_t rank) {
return m_mgr.gather_uid(uid, key, size, rank);
}
void set_output_shape(const std::string& key,
const TensorShape& shape) override {
m_mgr.set_output_shape(key, shape);
}
TensorShape get_output_shape(const std::string& key) override {
return m_mgr.get_output_shape(key);
}
uint32_t group_barrier(uint32_t size, uint32_t rank) override {
return m_mgr.group_barrier(size, rank);
}
private:
opr::GroupManager m_mgr;
};
const auto send_tag = RemoteIOBase::Type::SEND;
const auto recv_tag = RemoteIOBase::Type::RECV;
} // anonymous namespace
const auto send_tag = opr::RemoteIOBase::Type::SEND;
const auto recv_tag = opr::RemoteIOBase::Type::RECV;
TEST(TestOprIORemote, Identity) {
REQUIRE_GPU(2);
......@@ -69,7 +32,7 @@ TEST(TestOprIORemote, Identity) {
auto host_x = gen({28, 28});
HostTensorND host_y;
auto client = std::make_shared<MockGroupClient>();
auto client = std::make_shared<test::MockGroupClient>();
auto graph = ComputingGraph::make();
auto x = opr::Host2DeviceCopy::make(*graph, host_x, cn0);
......@@ -90,7 +53,7 @@ TEST(TestOprIORemote, IdentityMultiThread) {
HostTensorGenerator<> gen;
auto host_x = gen({2, 3}, cns[1]);
HostTensorND host_x_get;
auto client = std::make_shared<MockGroupClient>();
auto client = std::make_shared<test::MockGroupClient>();
auto sender = [&]() {
auto graph = ComputingGraph::make();
......@@ -123,7 +86,7 @@ TEST(TestOprIORemote, IdentityWithGopt) {
HostTensorGenerator<> gen;
auto host_x = gen({2, 3}, cns[0]);
HostTensorND host_x_get;
auto client = std::make_shared<MockGroupClient>();
auto client = std::make_shared<test::MockGroupClient>();
auto sender = [&]() {
sys::set_thread_name("sender");
......@@ -157,7 +120,7 @@ TEST(TestOprIORemote, APlusB) {
HostTensorGenerator<> gen;
auto host_x = gen({5, 7}, cns[0]), host_y = gen({5, 1}, cns[0]);
HostTensorND host_z;
auto client = std::make_shared<MockGroupClient>();
auto client = std::make_shared<test::MockGroupClient>();
auto sender = [&]() {
auto graph = ComputingGraph::make();
......@@ -208,7 +171,7 @@ TEST(TestOprIORemote, SendGrad) {
HostTensorGenerator<> gen;
auto host_x = gen({2, 3}, cns[0]);
HostTensorND host_gx, host_loss;
auto client = std::make_shared<MockGroupClient>();
auto client = std::make_shared<test::MockGroupClient>();
auto sender = [&]() {
sys::set_thread_name("sender");
......
/**
* \file src/opr-mm/test/mock_client.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#include "megbrain/opr/group_manager.h"
namespace mgb {
namespace test {
class MockGroupClient final : public opr::GroupClient {
public:
using RegisterInfo = opr::GroupManager::RegisterInfo;
MockGroupClient(const std::string& server_addr = "mock_addr") :
m_addr(server_addr) {
}
~MockGroupClient() override = default;
const std::string& get_addr() const {
return m_addr;
}
RegisterInfo opr_register(const std::string& key, size_t nr_devices,
bool is_root, int rank, uint64_t comp_node_hash) {
return m_mgr.opr_register(key, nr_devices, is_root, rank, comp_node_hash);
}
std::vector<std::string> gather_uid(const std::string& uid,
const std::string& key, uint32_t size, uint32_t rank) {
return m_mgr.gather_uid(uid, key, size, rank);
}
void set_output_shape(const std::string& key,
const TensorShape& shape) override {
m_mgr.set_output_shape(key, shape);
}
TensorShape get_output_shape(const std::string& key) override {
return m_mgr.get_output_shape(key);
}
uint32_t group_barrier(uint32_t size, uint32_t rank) override {
return m_mgr.group_barrier(size, rank);
}
private:
const std::string m_addr;
opr::GroupManager m_mgr;
};
} // namespace test
} // namespace mgb
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册