Commit 064c774f — authored by: Megvii Engine Team

feat(imperative): impl hashable for SendRecv and add virtual input for Recv

GitOrigin-RevId: 5e8c27ac81b844bbf8720a108da7ad208060dad2
Parent: 798f7b3e
......@@ -46,7 +46,7 @@ cg::OperatorNodeBase* apply_on_var_node_remote_recv(
ssprintf("%s:%d", recv.addr.data(), recv.port));
auto&& graph = inputs[0]->owner_graph();
return graph->insert_opr(std::make_unique<mgb::opr::RemoteRecv>(
recv.key, *graph, group_client, OperatorNodeConfig{recv.cn},
recv.key, inputs[0], *graph, group_client, OperatorNodeConfig{recv.cn},
recv.shape, recv.dtype));
}
......@@ -60,6 +60,43 @@ OP_TRAIT_REG(RemoteRecv, RemoteRecv, mgb::opr::RemoteRecv)
} // anonymous namespace
#endif // MGB_ENABLE_OPR_MM
bool RemoteSend::is_same_st(const Hashable& another) const {
    // Equality is defined field-wise over the identity tuple (see as_tuple()).
    auto&& rhs = another.cast_final<RemoteSend>();
    return rhs.as_tuple() == as_tuple();
}
size_t RemoteSend::hash() const{
XXHash xxhash;
auto append = [&xxhash](auto field){
auto hash_val = HashTrait<decltype(field)>::eval(field);
xxhash.update(reinterpret_cast<void*>(&hash_val), sizeof(hash_val));
};
append(key);
append(addr);
append(port);
append(rank_to);
return xxhash.digest();
}
bool RemoteRecv::is_same_st(const Hashable& another) const {
    // Equality is defined field-wise over the identity tuple (see as_tuple()).
    auto&& rhs = another.cast_final<RemoteRecv>();
    return rhs.as_tuple() == as_tuple();
}
size_t RemoteRecv::hash() const{
XXHash xxhash;
auto append = [&xxhash](auto field){
auto hash_val = HashTrait<decltype(field)>::eval(field);
xxhash.update(reinterpret_cast<void*>(&hash_val), sizeof(hash_val));
};
append(key);
append(addr);
append(port);
append(rank_from);
append(cn.to_string());
append(dtype.handle());
append(shape.to_string());
return xxhash.digest();
}
// Dynamic-type boilerplate: emits the out-of-line RTTI/typeinfo definitions
// required for cast_final<>() on these final op classes.
MGB_DYN_TYPE_OBJ_FINAL_IMPL(RemoteSend);
MGB_DYN_TYPE_OBJ_FINAL_IMPL(RemoteRecv);
......
......@@ -31,6 +31,13 @@ public:
std::string addr;
uint32_t port;
uint32_t rank_to;
size_t hash() const override;
bool is_same_st(const Hashable& another) const override;
//! Identity tuple used by hash() and is_same_st(). Returned as a tuple of
//! references (std::tie) so the string members are not copied on every
//! equality check; elementwise comparison semantics are unchanged.
auto as_tuple() const {
    return std::tie(key, addr, port, rank_to);
}
};
class RemoteRecv : public OpDefImplBase<RemoteRecv> {
......@@ -55,6 +62,13 @@ public:
CompNode cn;
TensorShape shape;
DType dtype;
size_t hash() const override;
bool is_same_st(const Hashable& another) const override;
//! Identity tuple used by hash() and is_same_st(). TensorShape is compared
//! via its string form — presumably because it has no suitable operator==;
//! TODO confirm. Must stay in sync with the fields hashed in hash().
auto as_tuple() const {
    return std::tuple(key, addr, port, rank_from, cn, dtype, shape.to_string());
}
};
} // namespace imperative
......
......@@ -151,6 +151,23 @@ RemoteRecv::RemoteRecv(const std::string& key, cg::ComputingGraph& graph,
add_equivalence_component<ScalarHash<void*>>(this);
}
/*!
 * \brief ctor variant that additionally registers \p var as an input of the
 *        receive opr.
 *
 * The extra input is not read by the receive itself; it presumably acts as a
 * dependency edge so the graph keeps/orders this opr relative to \p var —
 * TODO confirm against the callers.
 */
RemoteRecv::RemoteRecv(const std::string& key, VarNode* var, cg::ComputingGraph& graph,
        std::shared_ptr<GroupClient> group_client,
        const OperatorNodeConfig& config,
        const TensorShape& shape, DType dtype) :
        Super(&graph, config, "remote_recv", {}),
        m_shape(shape), m_dtype(dtype) {
    m_key = key;
    m_group_client = group_client;
    // the "virtual" input: establishes the dependency on var
    add_input({var});
    // single output holding the received tensor; its storage must not be
    // reclaimed or force-switched to dynamic allocation, per the flags below
    add_output(None)
            ->dtype(dtype)
            .add_flag(VarNode::Flag::NO_MEM_RECLAIM)
            .add_flag(VarNode::Flag::DISALLOW_RT_FORCE_DYNAMIC_MEM_ALLOC);
    // hash by object identity (this-pointer) so graph-level deduplication
    // never merges two distinct receive oprs
    add_equivalence_component<ScalarHash<void*>>(this);
}
SymbolVar RemoteRecv::make(const std::string& key, cg::ComputingGraph& graph,
std::shared_ptr<GroupClient> group_client,
const OperatorNodeConfig& config,
......@@ -160,6 +177,15 @@ SymbolVar RemoteRecv::make(const std::string& key, cg::ComputingGraph& graph,
return opr->output(0);
}
SymbolVar RemoteRecv::make(const std::string& key, SymbolVar var, cg::ComputingGraph& graph,
        std::shared_ptr<GroupClient> group_client,
        const OperatorNodeConfig& config,
        const TensorShape& shape, DType dtype) {
    //! insert a RemoteRecv opr that also takes `var` as an input, and expose
    //! its single output as a SymbolVar
    auto&& inserted = graph.insert_opr(std::make_unique<RemoteRecv>(
            key, var.node(), graph, group_client, config, shape, dtype));
    return inserted->output(0);
}
void RemoteRecv::scn_do_execute() {
if (!m_init) {
auto&& comp_node = output(0)->comp_node();
......
......@@ -77,12 +77,23 @@ MGB_DEFINE_OPR_CLASS(RemoteRecv, RemoteIOBase) // {
const OperatorNodeConfig& config, const TensorShape& shape,
DType dtype);
RemoteRecv(const std::string& key, VarNode* var, cg::ComputingGraph& graph,
std::shared_ptr<GroupClient> group_client,
const OperatorNodeConfig& config, const TensorShape& shape,
DType dtype);
static SymbolVar make(
const std::string& key, cg::ComputingGraph& graph,
std::shared_ptr<GroupClient> group_client,
const OperatorNodeConfig& config, const TensorShape& shape,
DType dtype);
static SymbolVar make(
const std::string& key, SymbolVar var, cg::ComputingGraph& graph,
std::shared_ptr<GroupClient> group_client,
const OperatorNodeConfig& config, const TensorShape& shape,
DType dtype);
private:
const TensorShape m_shape;
const DType m_dtype;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册